remove gfpgan dep

parent ffdf5e06e1
commit 60b1411d6b

@@ -8,4 +8,3 @@ def install(package):
 
 def install_plugins_package():
     install("rembg")
-    install("gfpgan")
172  iopaint/plugins/basicsr/img_util.py  Normal file
@@ -0,0 +1,172 @@
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result


def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output


def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also norm
            to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    ok = cv2.imwrite(file_path, img, params)
    if not ok:
        raise IOError('Failed in writing images.')


def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray]: Cropped images.
    """
    if crop_border == 0:
        return imgs
    else:
        if isinstance(imgs, list):
            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
        else:
            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
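Note: a minimal round-trip sketch of the img2tensor/tensor2img helpers above (illustration only, not part of the commit; "example.jpg" is a hypothetical path):

    import cv2
    import numpy as np
    from iopaint.plugins.basicsr.img_util import img2tensor, tensor2img

    # Hypothetical input path; any 8-bit BGR image read by OpenCV works.
    img = cv2.imread("example.jpg").astype(np.float32) / 255.0  # HWC, BGR, [0, 1]
    t = img2tensor(img, bgr2rgb=True, float32=True)             # CHW, RGB torch tensor
    back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))          # HWC, BGR, uint8
    assert back.shape == img.shape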
135  iopaint/plugins/facexlib/.gitignore  vendored  Normal file
@@ -0,0 +1,135 @@
.vscode
*.pth
*.png
*.jpg
version.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
3  iopaint/plugins/facexlib/__init__.py  Normal file
@@ -0,0 +1,3 @@
# flake8: noqa
from .detection import *
from .utils import *
31  iopaint/plugins/facexlib/detection/__init__.py  Normal file
@@ -0,0 +1,31 @@
import torch
from copy import deepcopy

from ..utils import load_file_from_url
from .retinaface import RetinaFace


def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
    if model_name == 'retinaface_resnet50':
        model = RetinaFace(network_name='resnet50', half=half, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
    elif model_name == 'retinaface_mobile0.25':
        model = RetinaFace(network_name='mobile0.25', half=half, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)

    # TODO: clean pretrained model
    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
    # remove unnecessary 'module.'
    for k, v in deepcopy(load_net).items():
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
    model.load_state_dict(load_net, strict=True)
    model.eval()
    model = model.to(device)
    return model
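Note: a usage sketch for init_detection_model (illustration only; weights are fetched from the GitHub release URL above on first call, 'cpu' is chosen here so no GPU is assumed, and "example.jpg" is a hypothetical path):

    import cv2
    import torch
    from iopaint.plugins.facexlib.detection import init_detection_model

    detector = init_detection_model('retinaface_resnet50', half=False, device='cpu')
    img = cv2.imread("example.jpg")  # hypothetical path; BGR uint8 array
    with torch.no_grad():
        faces = detector.detect_faces(img, conf_threshold=0.8)
    # Each row: x1, y1, x2, y2, score, then five landmark (x, y) pairs.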
219  iopaint/plugins/facexlib/detection/align_trans.py  Normal file
@@ -0,0 +1,219 @@
import cv2
import numpy as np

from .matlab_cp2tform import get_similarity_transform_for_cv2

# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]

DEFAULT_CROP_SIZE = (96, 112)


class FaceWarpException(Exception):

    def __str__(self):
        return 'In File {}:{}'.format(__file__, super().__str__())


def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
    """
    Function:
    ----------
    get reference 5 key points according to crop settings:
    0. Set default crop_size:
        if default_square:
            crop_size = (112, 112)
        else:
            crop_size = (96, 112)
    1. Pad the crop_size by inner_padding_factor in each side;
    2. Resize crop_size into (output_size - outer_padding*2),
       pad into output_size with outer_padding;
    3. Output reference_5point;
    Parameters:
    ----------
    @output_size: (w, h) or None
        size of aligned face image
    @inner_padding_factor: (w_factor, h_factor)
        padding factor for inner (w, h)
    @outer_padding: (w_pad, h_pad)
        each row is a pair of coordinates (x, y)
    @default_square: True or False
        if True:
            default crop_size = (112, 112)
        else:
            default crop_size = (96, 112);
    !!! make sure, if output_size is not None:
            (output_size - outer_padding)
            = some_scale * (default crop_size * (1.0 + inner_padding_factor))
    Returns:
    ----------
    @reference_5point: 5x2 np.array
        each row is a pair of transformed coordinates (x, y)
    """

    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
        return tmp_5pts

    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
        if output_size is None:
            return tmp_5pts
        else:
            raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
        output_size = tmp_crop_size * \
            (1 + inner_padding_factor * 2).astype(np.int32)
        output_size += np.array(outer_padding)
    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding)'
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor
    # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
    # tmp_5pts = tmp_5pts + size_diff / 2
    tmp_crop_size = size_bf_outer_pad

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)
    tmp_crop_size = output_size

    return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
    """
    Function:
    ----------
    get affine transform matrix 'tfm' from src_pts to dst_pts
    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points matrix, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points matrix, each row is a pair of coordinates (x, y)
    Returns:
    ----------
    @tfm: 2x3 np.array
        transform matrix from src_pts to dst_pts
    """

    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)

    if rank == 3:
        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
    elif rank == 2:
        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])

    return tfm


def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
    """
    Function:
    ----------
    apply affine transform 'trans' to uv
    Parameters:
    ----------
    @src_img: 3x3 np.array
        input image
    @facial_pts: could be
        1) a list of K coordinates (x,y)
        or
        2) Kx2 or 2xK np.array
           each row or col is a pair of coordinates (x, y)
    @reference_pts: could be
        1) a list of K coordinates (x,y)
        or
        2) Kx2 or 2xK np.array
           each row or col is a pair of coordinates (x, y)
        or
        3) None
           if None, use default reference facial points
    @crop_size: (w, h)
        output face image size
    @align_type: transform type, could be one of
        1) 'similarity': use similarity transform
        2) 'cv2_affine': use the first 3 points to do affine transform,
           by calling cv2.getAffineTransform()
        3) 'affine': use all points to do affine transform
    Returns:
    ----------
    @face_img: output face image with size (w, h) = @crop_size
    """

    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
                                                        default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException('facial_pts and reference_pts must have the same shape')

    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
    else:
        tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))

    return face_img
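Note: an alignment sketch using the functions above (illustration only; the image path and the five landmarks are hypothetical values in detector order: eyes, nose, mouth corners):

    import cv2
    import numpy as np
    from iopaint.plugins.facexlib.detection.align_trans import (
        get_reference_facial_points, warp_and_crop_face)

    img = cv2.imread("example.jpg")  # hypothetical path
    pts = np.array([[130., 110.], [180., 108.], [155., 140.],
                    [138., 170.], [175., 168.]])  # hypothetical 5-point landmarks
    ref = get_reference_facial_points(default_square=True)  # 112x112 reference template
    face = warp_and_crop_face(img, pts, reference_pts=ref, crop_size=(112, 112))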
317  iopaint/plugins/facexlib/detection/matlab_cp2tform.py  Normal file
@@ -0,0 +1,317 @@
import numpy as np
from numpy.linalg import inv, lstsq
from numpy.linalg import matrix_rank as rank
from numpy.linalg import norm


class MatlabCp2tormException(Exception):

    def __str__(self):
        return 'In File {}:{}'.format(__file__, super().__str__())


def tformfwd(trans, uv):
    """
    Function:
    ----------
    apply affine transform 'trans' to uv

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix
    @uv: Kx2 np.array
        each row is a pair of coordinates (x, y)

    Returns:
    ----------
    @xy: Kx2 np.array
        each row is a pair of transformed coordinates (x, y)
    """
    uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
    xy = np.dot(uv, trans)
    xy = xy[:, 0:-1]
    return xy


def tforminv(trans, uv):
    """
    Function:
    ----------
    apply the inverse of affine transform 'trans' to uv

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix
    @uv: Kx2 np.array
        each row is a pair of coordinates (x, y)

    Returns:
    ----------
    @xy: Kx2 np.array
        each row is a pair of inverse-transformed coordinates (x, y)
    """
    Tinv = inv(trans)
    xy = tformfwd(Tinv, uv)
    return xy


def findNonreflectiveSimilarity(uv, xy, options=None):
    options = {'K': 2}

    K = options['K']
    M = xy.shape[0]
    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector

    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
    X = np.vstack((tmp1, tmp2))

    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    U = np.vstack((u, v))

    # We know that X * r = U
    if rank(X) >= 2 * K:
        r, _, _, _ = lstsq(X, U, rcond=-1)
        r = np.squeeze(r)
    else:
        raise Exception('cp2tform:twoUniquePointsReq')
    sc = r[0]
    ss = r[1]
    tx = r[2]
    ty = r[3]

    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
    T = inv(Tinv)
    T[:, 2] = np.array([0, 0, 1])

    return T, Tinv


def findSimilarity(uv, xy, options=None):
    options = {'K': 2}

    # uv = np.array(uv)
    # xy = np.array(xy)

    # Solve for trans1
    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)

    # Solve for trans2

    # manually reflect the xy data across the Y-axis
    xyR = xy
    xyR[:, 0] = -1 * xyR[:, 0]

    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)

    # manually reflect the tform to undo the reflection done on xyR
    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])

    trans2 = np.dot(trans2r, TreflectY)

    # Figure out if trans1 or trans2 is better
    xy1 = tformfwd(trans1, uv)
    norm1 = norm(xy1 - xy)

    xy2 = tformfwd(trans2, uv)
    norm2 = norm(xy2 - xy)

    if norm1 <= norm2:
        return trans1, trans1_inv
    else:
        trans2_inv = inv(trans2)
        return trans2, trans2_inv


def get_similarity_transform(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
    Find Similarity Transform Matrix 'trans':
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y, 1] = [u, v, 1] * trans

    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points, each row is a pair of transformed
        coordinates (x, y)
    @reflective: True or False
        if True:
            use reflective similarity transform
        else:
            use non-reflective similarity transform

    Returns:
    ----------
    @trans: 3x3 np.array
        transform matrix from uv to xy
    trans_inv: 3x3 np.array
        inverse of trans, transform matrix from xy to uv
    """

    if reflective:
        trans, trans_inv = findSimilarity(src_pts, dst_pts)
    else:
        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)

    return trans, trans_inv


def cvt_tform_mat_for_cv2(trans):
    """
    Function:
    ----------
    Convert Transform Matrix 'trans' into 'cv2_trans' which could be
    directly used by cv2.warpAffine():
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix from uv to xy

    Returns:
    ----------
    @cv2_trans: 2x3 np.array
        transform matrix from src_pts to dst_pts, could be directly used
        for cv2.warpAffine()
    """
    cv2_trans = trans[:, 0:2].T

    return cv2_trans


def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
    Find Similarity Transform Matrix 'cv2_trans' which could be
    directly used by cv2.warpAffine():
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points, each row is a pair of transformed
        coordinates (x, y)
    reflective: True or False
        if True:
            use reflective similarity transform
        else:
            use non-reflective similarity transform

    Returns:
    ----------
    @cv2_trans: 2x3 np.array
        transform matrix from src_pts to dst_pts, could be directly used
        for cv2.warpAffine()
    """
    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
    cv2_trans = cvt_tform_mat_for_cv2(trans)

    return cv2_trans


if __name__ == '__main__':
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    # In Matlab, run:
    #
    #   uv = [u'; v'];
    #   xy = [x'; y'];
    #   tform_sim = cp2tform(uv, xy, 'similarity');
    #
    #   trans = tform_sim.tdata.T
    #   ans =
    #       -0.0764   -1.6190         0
    #        1.6190   -0.0764         0
    #       -3.2156    0.0290    1.0000
    #   trans_inv = tform_sim.tdata.Tinv
    #   ans =
    #
    #       -0.0291    0.6163         0
    #       -0.6163   -0.0291         0
    #       -0.0756    1.9826    1.0000
    #   xy_m = tformfwd(tform_sim, u, v)
    #
    #   xy_m =
    #
    #       -3.2156    0.0290
    #        1.1833   -9.9143
    #        5.0323    2.8853
    #   uv_m = tforminv(tform_sim, x, y)
    #
    #   uv_m =
    #
    #        0.5698    1.3953
    #        6.0872    2.2733
    #       -2.6570    4.3314
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    uv = np.array((u, v)).T
    xy = np.array((x, y)).T

    print('\n--->uv:')
    print(uv)
    print('\n--->xy:')
    print(xy)

    trans, trans_inv = get_similarity_transform(uv, xy)

    print('\n--->trans matrix:')
    print(trans)

    print('\n--->trans_inv matrix:')
    print(trans_inv)

    print('\n---> apply transform to uv')
    print('\nxy_m = uv_augmented * trans')
    uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
    xy_m = np.dot(uv_aug, trans)
    print(xy_m)

    print('\nxy_m = tformfwd(trans, uv)')
    xy_m = tformfwd(trans, uv)
    print(xy_m)

    print('\n---> apply inverse transform to xy')
    print('\nuv_m = xy_augmented * trans_inv')
    xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
    uv_m = np.dot(xy_aug, trans_inv)
    print(uv_m)

    print('\nuv_m = tformfwd(trans_inv, xy)')
    uv_m = tformfwd(trans_inv, xy)
    print(uv_m)

    uv_m = tforminv(trans, xy)
    print('\nuv_m = tforminv(trans, xy)')
    print(uv_m)
419  iopaint/plugins/facexlib/detection/retinaface.py  Normal file
@@ -0,0 +1,419 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter

from .align_trans import get_reference_facial_points, warp_and_crop_face
from .retinaface_net import (
    FPN,
    SSH,
    MobileNetV1,
    make_bbox_head,
    make_class_head,
    make_landmark_head,
)
from .retinaface_utils import (
    PriorBox,
    batched_decode,
    batched_decode_landm,
    decode,
    decode_landm,
    py_cpu_nms,
)


def generate_config(network_name):
    cfg_mnet = {
        "name": "mobilenet0.25",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 32,
        "ngpu": 1,
        "epoch": 250,
        "decay1": 190,
        "decay2": 220,
        "image_size": 640,
        "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
        "in_channel": 32,
        "out_channel": 64,
    }

    cfg_re50 = {
        "name": "Resnet50",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 24,
        "ngpu": 4,
        "epoch": 100,
        "decay1": 70,
        "decay2": 90,
        "image_size": 840,
        "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
        "in_channel": 256,
        "out_channel": 256,
    }

    if network_name == "mobile0.25":
        return cfg_mnet
    elif network_name == "resnet50":
        return cfg_re50
    else:
        raise NotImplementedError(f"network_name={network_name}")


class RetinaFace(nn.Module):
    def __init__(self, network_name="resnet50", half=False, phase="test", device=None):
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )

        super(RetinaFace, self).__init__()
        self.half_inference = half
        cfg = generate_config(network_name)
        self.backbone = cfg["name"]

        self.model_name = f"retinaface_{network_name}"
        self.cfg = cfg
        self.phase = phase
        self.target_size, self.max_size = 1600, 2150
        self.resize, self.scale, self.scale1 = 1.0, None, None
        self.mean_tensor = torch.tensor(
            [[[[104.0]], [[117.0]], [[123.0]]]], device=self.device
        )
        self.reference = get_reference_facial_points(default_square=True)
        # Build network.
        backbone = None
        if cfg["name"] == "mobilenet0.25":
            backbone = MobileNetV1()
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
        elif cfg["name"] == "Resnet50":
            import torchvision.models as models

            backbone = models.resnet50(pretrained=False)
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])

        in_channels_stage2 = cfg["in_channel"]
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]

        out_channels = cfg["out_channel"]
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])

        self.to(self.device)
        self.eval()
        if self.half_inference:
            self.half()

    def forward(self, inputs):
        out = self.body(inputs)

        if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50":
            out = list(out.values())
        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
        ldm_regressions = torch.cat(tmp, dim=1)

        if self.phase == "train":
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (
                bbox_regressions,
                F.softmax(classifications, dim=-1),
                ldm_regressions,
            )
        return output

    def __detect_faces(self, inputs):
        # get scale
        height, width = inputs.shape[2:]
        self.scale = torch.tensor(
            [width, height, width, height], dtype=torch.float32, device=self.device
        )
        tmp = [width, height, width, height, width, height, width, height, width, height]
        self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)

        # forward
        inputs = inputs.to(self.device)
        if self.half_inference:
            inputs = inputs.half()
        loc, conf, landmarks = self(inputs)

        # get priorbox
        priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
        priors = priorbox.forward().to(self.device)

        return loc, conf, landmarks, priors

    # single image detection
    def transform(self, image, use_origin_size):
        # convert to opencv format
        if isinstance(image, Image.Image):
            image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)

        # testing scale
        im_size_min = np.min(image.shape[0:2])
        im_size_max = np.max(image.shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            image = cv2.resize(
                image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
            )

        # convert to torch.tensor format
        # image -= (104, 117, 123)
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).unsqueeze(0)

        return image, resize

    def detect_faces(
        self,
        image,
        conf_threshold=0.8,
        nms_threshold=0.4,
        use_origin_size=True,
    ):
        image, self.resize = self.transform(image, use_origin_size)
        image = image.to(self.device)
        if self.half_inference:
            image = image.half()
        image = image - self.mean_tensor

        loc, conf, landmarks, priors = self.__detect_faces(image)

        boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
        boxes = boxes * self.scale / self.resize
        boxes = boxes.cpu().numpy()

        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

        landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
        landmarks = landmarks * self.scale1 / self.resize
        landmarks = landmarks.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > conf_threshold)[0]
        boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

        # sort
        order = scores.argsort()[::-1]
        boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

        # do NMS
        bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False
        )
        keep = py_cpu_nms(bounding_boxes, nms_threshold)
        bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
        # self.t['forward_pass'].toc()
        # print(self.t['forward_pass'].average_time)
        # import sys
        # sys.stdout.flush()
        return np.concatenate((bounding_boxes, landmarks), axis=1)

    def __align_multi(self, image, boxes, landmarks, limit=None):
        if len(boxes) < 1:
            return [], []

        if limit:
            boxes = boxes[:limit]
            landmarks = landmarks[:limit]

        faces = []
        for landmark in landmarks:
            facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]

            warped_face = warp_and_crop_face(
                np.array(image), facial5points, self.reference, crop_size=(112, 112)
            )
            faces.append(warped_face)

        return np.concatenate((boxes, landmarks), axis=1), faces

    def align_multi(self, img, conf_threshold=0.8, limit=None):
        rlt = self.detect_faces(img, conf_threshold=conf_threshold)
        boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]

        return self.__align_multi(img, boxes, landmarks, limit)

    # batched detection
    def batched_transform(self, frames, use_origin_size):
        """
        Arguments:
            frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
                type=np.float32, BGR format).
            use_origin_size: whether to use origin size.
        """
        from_PIL = True if isinstance(frames[0], Image.Image) else False

        # convert to opencv format
        if from_PIL:
            frames = [
                cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames
            ]
            frames = np.asarray(frames, dtype=np.float32)

        # testing scale
        im_size_min = np.min(frames[0].shape[0:2])
        im_size_max = np.max(frames[0].shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            if not from_PIL:
                frames = F.interpolate(frames, scale_factor=resize)
            else:
                frames = [
                    cv2.resize(
                        frame,
                        None,
                        None,
                        fx=resize,
                        fy=resize,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    for frame in frames
                ]

        # convert to torch.tensor format
        if not from_PIL:
            frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
        else:
            frames = frames.transpose((0, 3, 1, 2))
            frames = torch.from_numpy(frames)

        return frames, resize

    def batched_detect_faces(
        self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True
    ):
        """
        Arguments:
            frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
                type=np.uint8, BGR format).
            conf_threshold: confidence threshold.
            nms_threshold: nms threshold.
            use_origin_size: whether to use origin size.
        Returns:
            final_bounding_boxes: list of np.array ([n_boxes, 5],
                type=np.float32).
            final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
        """
        # self.t['forward_pass'].tic()
        frames, self.resize = self.batched_transform(frames, use_origin_size)
        frames = frames.to(self.device)
        frames = frames - self.mean_tensor

        b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)

        final_bounding_boxes, final_landmarks = [], []

        # decode
        priors = priors.unsqueeze(0)
        b_loc = (
            batched_decode(b_loc, priors, self.cfg["variance"])
            * self.scale
            / self.resize
        )
        b_landmarks = (
            batched_decode_landm(b_landmarks, priors, self.cfg["variance"])
            * self.scale1
            / self.resize
        )
        b_conf = b_conf[:, :, 1]

        # index for selection
        b_indice = b_conf > conf_threshold

        # concat
        b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()

        for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
            # ignore low scores
            pred, landm = pred[inds, :], landm[inds, :]
            if pred.shape[0] == 0:
                final_bounding_boxes.append(np.array([], dtype=np.float32))
                final_landmarks.append(np.array([], dtype=np.float32))
                continue

            # sort
            # order = score.argsort(descending=True)
            # box, landm, score = box[order], landm[order], score[order]

            # to CPU
            bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()

            # NMS
            keep = py_cpu_nms(bounding_boxes, nms_threshold)
            bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]

            # append
            final_bounding_boxes.append(bounding_boxes)
            final_landmarks.append(landmarks)
        # self.t['forward_pass'].toc(average=True)
        # self.batch_time += self.t['forward_pass'].diff
        # self.total_frame += len(frames)
        # print(self.batch_time / self.total_frame)

        return final_bounding_boxes, final_landmarks
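Note: a sketch of the detect-and-align entry point above (illustration only; it builds on the init_detection_model helper from the detection/__init__.py hunk, and "example.jpg" is a hypothetical path):

    import cv2
    import torch
    from iopaint.plugins.facexlib.detection import init_detection_model

    detector = init_detection_model('retinaface_resnet50', device='cpu')
    img = cv2.imread("example.jpg")  # hypothetical path; BGR uint8
    with torch.no_grad():
        det, faces = detector.align_multi(img, conf_threshold=0.8, limit=4)
    # det: [n, 15] array (box, score, 10 landmark coords);
    # faces: list of 112x112 BGR crops warped to the reference template.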
196  iopaint/plugins/facexlib/detection/retinaface_net.py  Normal file
@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(inp, oup, stride=1, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_bn_no_relu(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )


def conv_bn1X1(inp, oup, stride, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_dw(inp, oup, stride, leaky=0.1):
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class SSH(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        leaky = 0
        if (out_channel <= 64):
            leaky = 0.1
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

        self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)

        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)

        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        out = F.relu(out)
        return out


class FPN(nn.Module):

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0
        if (out_channels <= 64):
            leaky = 0.1
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        # names = list(input.keys())
        # input = list(input.values())

        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])

        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
        output2 = output2 + up3
        output2 = self.merge2(output2)

        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
        output1 = output1 + up2
        output1 = self.merge1(output1)

        out = [output1, output2, output3]
        return out


class MobileNetV1(nn.Module):

    def __init__(self):
        super(MobileNetV1, self).__init__()
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            conv_dw(8, 16, 1),  # 7
            conv_dw(16, 32, 2),  # 11
            conv_dw(32, 32, 1),  # 19
            conv_dw(32, 64, 2),  # 27
            conv_dw(64, 64, 1),  # 43
        )
        self.stage2 = nn.Sequential(
            conv_dw(64, 128, 2),  # 43 + 16 = 59
            conv_dw(128, 128, 1),  # 59 + 32 = 91
            conv_dw(128, 128, 1),  # 91 + 32 = 123
            conv_dw(128, 128, 1),  # 123 + 32 = 155
            conv_dw(128, 128, 1),  # 155 + 32 = 187
            conv_dw(128, 128, 1),  # 187 + 32 = 219
        )
        self.stage3 = nn.Sequential(
            conv_dw(128, 256, 2),  # 219 + 32 = 241
            conv_dw(256, 256, 1),  # 241 + 64 = 301
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        # x = self.model(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x


class ClassHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 10)


def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
    classhead = nn.ModuleList()
    for i in range(fpn_num):
        classhead.append(ClassHead(inchannels, anchor_num))
    return classhead


def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
    bboxhead = nn.ModuleList()
    for i in range(fpn_num):
        bboxhead.append(BboxHead(inchannels, anchor_num))
    return bboxhead


def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
    landmarkhead = nn.ModuleList()
    for i in range(fpn_num):
        landmarkhead.append(LandmarkHead(inchannels, anchor_num))
    return landmarkhead
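Note: a quick shape sanity check for the FPN/SSH/head blocks above (illustration only; the channel counts mirror the mobile0.25 config in retinaface.py, and the input tensors are random placeholders):

    import torch
    from iopaint.plugins.facexlib.detection.retinaface_net import FPN, SSH, make_class_head

    fpn = FPN(in_channels_list=[64, 128, 256], out_channels=64)
    feats = [torch.randn(1, 64, 80, 80), torch.randn(1, 128, 40, 40),
             torch.randn(1, 256, 20, 20)]
    p1, p2, p3 = fpn(feats)  # all 64-channel, spatial sizes 80/40/20
    scores = make_class_head(fpn_num=3, inchannels=64)[0](SSH(64, 64)(p1))
    print(scores.shape)  # torch.Size([1, 12800, 2]): 80*80*2 anchors, 2 classes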
421  iopaint/plugins/facexlib/detection/retinaface_utils.py  Normal file
@@ -0,0 +1,421 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from itertools import product as product
|
||||
from math import ceil
|
||||
|
||||
|
||||
class PriorBox(object):
|
||||
|
||||
def __init__(self, cfg, image_size=None, phase='train'):
|
||||
super(PriorBox, self).__init__()
|
||||
self.min_sizes = cfg['min_sizes']
|
||||
self.steps = cfg['steps']
|
||||
self.clip = cfg['clip']
|
||||
self.image_size = image_size
|
||||
self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
|
||||
self.name = 's'
|
||||
|
||||
def forward(self):
|
||||
anchors = []
|
||||
for k, f in enumerate(self.feature_maps):
|
||||
min_sizes = self.min_sizes[k]
|
||||
for i, j in product(range(f[0]), range(f[1])):
|
||||
for min_size in min_sizes:
|
||||
s_kx = min_size / self.image_size[1]
|
||||
s_ky = min_size / self.image_size[0]
|
||||
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
|
||||
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
|
||||
for cy, cx in product(dense_cy, dense_cx):
|
||||
anchors += [cx, cy, s_kx, s_ky]
|
||||
|
||||
# back to torch land
|
||||
output = torch.Tensor(anchors).view(-1, 4)
|
||||
if self.clip:
|
||||
output.clamp_(max=1, min=0)
|
||||
return output
|
||||
|
||||
|
||||
def py_cpu_nms(dets, thresh):
|
||||
"""Pure Python NMS baseline."""
|
||||
keep = torchvision.ops.nms(
|
||||
boxes=torch.Tensor(dets[:, :4]),
|
||||
scores=torch.Tensor(dets[:, 4]),
|
||||
iou_threshold=thresh,
|
||||
)
|
||||
|
||||
return list(keep)
|
||||
|
||||
|
||||
def point_form(boxes):
|
||||
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
|
||||
representation for comparison to point form ground truth data.
|
||||
Args:
|
||||
boxes: (tensor) center-size default boxes from priorbox layers.
|
||||
Return:
|
||||
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
|
||||
"""
|
||||
return torch.cat(
|
||||
(
|
||||
boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
|
||||
boxes[:, :2] + boxes[:, 2:] / 2),
|
||||
1) # xmax, ymax
|
||||
|
||||
|
||||
def center_size(boxes):
|
||||
""" Convert prior_boxes to (cx, cy, w, h)
|
||||
representation for comparison to center-size form ground truth data.
|
||||
Args:
|
||||
boxes: (tensor) point_form boxes
|
||||
Return:
|
||||
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
|
||||
"""
|
||||
return torch.cat(
|
||||
(boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
|
||||
boxes[:, 2:] - boxes[:, :2],
|
||||
1) # w, h
|
||||
|
||||
|
||||
def intersect(box_a, box_b):
|
||||
""" We resize both tensors to [A,B,2] without new malloc:
|
||||
[A,2] -> [A,1,2] -> [A,B,2]
|
||||
[B,2] -> [1,B,2] -> [A,B,2]
|
||||
Then we compute the area of intersect between box_a and box_b.
|
||||
Args:
|
||||
box_a: (tensor) bounding boxes, Shape: [A,4].
|
||||
box_b: (tensor) bounding boxes, Shape: [B,4].
|
||||
Return:
|
||||
(tensor) intersection area, Shape: [A,B].
|
||||
"""
|
||||
A = box_a.size(0)
|
||||
B = box_b.size(0)
|
||||
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
|
||||
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
|
||||
inter = torch.clamp((max_xy - min_xy), min=0)
|
||||
return inter[:, :, 0] * inter[:, :, 1]
|
||||
|
||||
|
||||
def jaccard(box_a, box_b):
|
||||
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
|
||||
is simply the intersection over union of two boxes. Here we operate on
|
||||
ground truth boxes and default boxes.
|
||||
E.g.:
|
||||
A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
|
||||
Args:
|
||||
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
|
||||
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
|
||||
Return:
|
||||
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
|
||||
"""
|
||||
inter = intersect(box_a, box_b)
|
||||
area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
|
||||
area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
|
||||
union = area_a + area_b - inter
|
||||
return inter / union # [A,B]
|
||||
|
||||
|
||||
def matrix_iou(a, b):
|
||||
"""
|
||||
return iou of a and b, numpy version for data augenmentation
|
||||
"""
|
||||
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
|
||||
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
|
||||
|
||||
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
|
||||
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
|
||||
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
|
||||
return area_i / (area_a[:, np.newaxis] + area_b - area_i)
|
||||
|
||||
|
||||
def matrix_iof(a, b):
|
||||
"""
|
||||
return iof of a and b, numpy version for data augenmentation
|
||||
"""
|
||||
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
|
||||
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
|
||||
|
||||
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
|
||||
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
|
||||
return area_i / np.maximum(area_a[:, np.newaxis], 1)
|
||||
|
||||
|
||||
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then return the matched indices
    corresponding to both confidence and location preds.

    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors, 4].
        variances: (tensor) Variances corresponding to each prior coord,
            Shape: [num_priors, 4].
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape: [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index.
    Return:
        The matched indices corresponding to 1) location, 2) confidence and
        3) landm preds.
    """
    # jaccard index
    overlaps = jaccard(truths, point_form(priors))
    # (Bipartite Matching)
    # [1, num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # [1, num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # decide which gt box this anchor is assigned to predict
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]  # Shape: [num_priors, 4]; gather the matched gt bbox for every anchor
    conf = labels[best_truth_idx]  # Shape: [num_priors]; gather the matched gt label for every anchor
    conf[best_truth_overlap < threshold] = 0  # label as background; anchors with overlap < 0.35 all become negatives
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc  # [num_priors, 4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.

    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form,
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form,
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """
    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors, 4]


def encode_landm(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth
    landmarks we have matched (based on jaccard overlap) with the prior boxes.

    Args:
        matched: (tensor) Coords of ground truth landmarks for each prior,
            Shape: [num_priors, 10].
        priors: (tensor) Prior boxes in center-offset form,
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded landm (tensor), Shape: [num_priors, 10]
    """
    # dist b/t match center and prior's center
    matched = torch.reshape(matched, (matched.size(0), 5, 2))
    priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
    g_cxcy = matched[:, :, :2] - priors[:, :, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, :, 2:])
    # g_cxcy /= priors[:, :, 2:]
    g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
    # return target for smooth_l1_loss
    return g_cxcy


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """
    boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
                       priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.

    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors, 10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    tmp = (
        priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
    )
    landms = torch.cat(tmp, dim=1)
    return landms

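
# Illustrative round-trip sketch (not part of the vendored file): encode()
# takes matched gt boxes in corner form plus priors in center-offset form,
# and decode() inverts it, so decode(encode(m, p, v), p, v) ~= m. Toy values:
#
#   priors = torch.tensor([[0.5, 0.5, 0.2, 0.2]])       # (cx, cy, w, h)
#   matched = torch.tensor([[0.42, 0.40, 0.60, 0.62]])  # (x1, y1, x2, y2)
#   variances = [0.1, 0.2]
#   offsets = encode(matched, priors, variances)
#   decode(offsets, priors, variances)  # -> [[0.42, 0.40, 0.60, 0.62]]
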
def batched_decode(b_loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.

    Args:
        b_loc (tensor): location predictions for loc layers,
            Shape: [num_batches, num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [1, num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """
    boxes = (
        priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
    )
    boxes = torch.cat(boxes, dim=2)

    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
    boxes[:, :, 2:] += boxes[:, :, :2]
    return boxes


def batched_decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.

    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_batches, num_priors, 10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [1, num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    landms = (
        priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
    )
    landms = torch.cat(landms, dim=2)
    return landms


def log_sum_exp(x):
    """Utility function for computing log_sum_exp in a numerically stable way.
    This will be used to determine the unaveraged confidence loss across
    all examples in a batch.

    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max


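# Why the x_max shift above is needed (illustrative, not part of the file):
# exp() overflows for large logits, while the identity
# log(sum(exp(x))) == x_max + log(sum(exp(x - x_max))) holds exactly.
#
#   x = torch.tensor([[1000.0, 1001.0]])
#   torch.log(torch.exp(x).sum(1, keepdim=True))  # -> inf (overflow)
#   log_sum_exp(x)                                # -> ~1001.3133 (stable)
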
# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors, 4].
        scores: (tensor) The class pred scores for the img, Shape: [num_priors].
        overlap: (float) The overlap threshold for suppressing unnecessary boxes.
        top_k: (int) The maximum number of box preds to consider.
    Return:
        The indices of the kept boxes with respect to num_priors.
    """
    keep = torch.Tensor(scores.size(0)).fill_(0).long()
    if boxes.numel() == 0:
        return keep
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    # I = I[v >= 0.01]
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    # keep = torch.Tensor()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        # keep.append(i)
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # store element-wise max with next highest score
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # check sizes of xx1 and xx2.. after each iteration
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union  # store result in iou
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
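
A minimal usage sketch for the nms above (the detections are made-up values, and nms is assumed to be imported from this module); keep[:count] holds the surviving indices, highest score first:

import torch

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0],     # IoU ~0.68 with box 0
                      [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep, count = nms(boxes, scores, overlap=0.5, top_k=200)
kept = boxes[keep[:count]]  # box 1 is suppressed; boxes 0 and 2 survive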
24
iopaint/plugins/facexlib/parsing/__init__.py
Normal file
@ -0,0 +1,24 @@
import torch

from ..utils import load_file_from_url
from .bisenet import BiSeNet
from .parsenet import ParseNet


def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
    if model_name == 'bisenet':
        model = BiSeNet(num_class=19)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
    elif model_name == 'parsenet':
        model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(load_net, strict=True)
    model.eval()
    model = model.to(device)
    return model
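
A minimal usage sketch, assuming the weights download succeeds on the first call; the input below is a dummy tensor, and real inputs are normalized 512x512 face crops:

import torch

net = init_parsing_model(model_name='parsenet', device='cpu')
with torch.no_grad():
    out_mask, out_img = net(torch.randn(1, 3, 512, 512))
# out_mask: (1, 19, 512, 512) logits over the 19 face-parsing classes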
140
iopaint/plugins/facexlib/parsing/bisenet.py
Normal file
@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import ResNet18


class ConvBNReLU(nn.Module):

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x


class BiSeNetOutput(nn.Module):

    def __init__(self, in_chan, mid_chan, num_class):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)

    def forward(self, x):
        feat = self.conv(x)
        out = self.conv_out(feat)
        return out, feat


class AttentionRefinementModule(nn.Module):

    def __init__(self, in_chan, out_chan):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out


class ContextPath(nn.Module):

    def __init__(self):
        super(ContextPath, self).__init__()
        self.resnet = ResNet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        h8, w8 = feat8.size()[2:]
        h16, w16 = feat16.size()[2:]
        h32, w32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (h32, w32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16


class FeatureFusionModule(nn.Module):

    def __init__(self, in_chan, out_chan):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out


class BiSeNet(nn.Module):

    def __init__(self, num_class):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, num_class)
        self.conv_out16 = BiSeNetOutput(128, 64, num_class)
        self.conv_out32 = BiSeNetOutput(128, 64, num_class)

    def forward(self, x, return_feat=False):
        h, w = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # return res3b1 feature
        feat_sp = feat_res8  # replace spatial path feature with res3b1 feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        out, feat = self.conv_out(feat_fuse)
        out16, feat16 = self.conv_out16(feat_cp8)
        out32, feat32 = self.conv_out32(feat_cp16)

        out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
        out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
        out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)

        if return_feat:
            feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
            feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
            feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
            return out, out16, out32, feat, feat16, feat32
        else:
            return out, out16, out32
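
A quick shape smoke test for BiSeNet (random weights, illustrative only):

import torch

net = BiSeNet(num_class=19).eval()
with torch.no_grad():
    out, out16, out32 = net(torch.randn(1, 3, 512, 512))
# all three are (1, 19, 512, 512): `out` is the main head on the fused
# features; out16/out32 are auxiliary heads on the 1/8 and 1/16 context maps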
194
iopaint/plugins/facexlib/parsing/parsenet.py
Normal file
@ -0,0 +1,194 @@
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
import numpy as np
import torch.nn as nn
from torch.nn import functional as F


class NormLayer(nn.Module):
    """Normalization Layers.

    Args:
        channels: input channels, for batch norm and instance norm.
        input_size: input shape without batch size, for layer norm.
    """

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = lambda x: x * 1.0
        else:
            assert 1 == 0, f'Norm type {norm_type} not supported.'

    def forward(self, x, ref=None):
        if self.norm_type == 'spade':
            return self.norm(x, ref)
        else:
            return self.norm(x)


class ReluLayer(nn.Module):
    """Relu Layer.

    Args:
        relu type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: default relu slope 0.2
            - PRelu
            - SELU
            - none: direct pass
    """

    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            self.func = lambda x: x * 1.0
        else:
            assert 1 == 0, f'Relu type {relu_type} not supported.'

    def forward(self, x):
        return self.func(x)


class ConvLayer(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 scale='none',
                 norm_type='none',
                 relu_type='none',
                 use_pad=True,
                 bias=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        if norm_type in ['bn']:
            bias = False

        stride = 2 if scale == 'down' else 1

        self.scale_func = lambda x: x
        if scale == 'up':
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')

        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out


class ResidualBlock(nn.Module):
    """
    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()

        if scale == 'none' and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
        scale_conf = scale_config_dict[scale]

        self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')

    def forward(self, x):
        identity = self.shortcut_func(x)

        res = self.conv1(x)
        res = self.conv2(res)
        return identity + res


class ParseNet(nn.Module):

    def __init__(self,
                 in_size=128,
                 out_size=128,
                 min_feat_size=32,
                 base_ch=64,
                 parsing_ch=19,
                 res_depth=10,
                 relu_type='LeakyReLU',
                 norm_type='bn',
                 ch_range=[32, 256]):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
        min_feat_size = min(in_size, min_feat_size)

        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        self.body = []
        for i in range(res_depth):
            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))

        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        feat = self.encoder(x)
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
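
For orientation, the geometry used by init_parsing_model above (in_size=out_size=512, min_feat_size=32) gives log2(512/32) = 4 down and 4 up steps, with channels clipped into ch_range=[32, 256]; a random-weight sketch:

import torch

net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
with torch.no_grad():
    out_mask, out_img = net(torch.randn(1, 3, 512, 512))
# out_mask: (1, 19, 512, 512) parsing logits; out_img: (1, 3, 512, 512)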
69
iopaint/plugins/facexlib/parsing/resnet.py
Normal file
@ -0,0 +1,69 @@
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):

    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32
7
iopaint/plugins/facexlib/utils/__init__.py
Normal file
@ -0,0 +1,7 @@
from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
from .misc import img2tensor, load_file_from_url, scandir

__all__ = [
    'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back',
    'img2tensor', 'scandir'
]
473
iopaint/plugins/facexlib/utils/face_restoration_helper.py
Normal file
@ -0,0 +1,473 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from ..detection import init_detection_model
from ..parsing import init_parsing_model
from ..utils.misc import img2tensor, imwrite


def get_largest_face(det_faces, h, w):
    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array(
            [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]
        )
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(
        self,
        upscale_factor,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        save_ext="png",
        template_3points=False,
        pad_blur=False,
        use_parse=False,
        device=None,
        model_rootpath=None,
    ):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = upscale_factor
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (
            self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1
        ), "crop ratio only supports >= 1"
        self.face_size = (
            int(face_size * self.crop_ratio[1]),
            int(face_size * self.crop_ratio[0]),
        )

        if self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            self.face_template = np.array(
                [
                    [192.98138, 239.94708],
                    [318.90277, 240.1936],
                    [256.63416, 314.01935],
                    [201.26117, 371.41043],
                    [313.08905, 371.15118],
                ]
            )
        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # init face detection model
        self.face_det = init_detection_model(
            det_model, half=False, device=self.device, model_rootpath=model_rootpath
        )

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(
            model_name="parsenet", device=self.device, model_rootpath=model_rootpath
        )

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be an image path or a cv2-loaded image."""
        # self.input_img is a Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img

    def get_face_landmarks_5(
        self,
        only_keep_largest=False,
        only_center_face=False,
        resize=None,
        blur_ratio=0.01,
        eye_dist_threshold=None,
    ):
        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = min(h, w) / resize
            h, w = int(h / scale), int(w / scale)
            input_img = cv2.resize(
                self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4
            )

        with torch.no_grad():
            bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])
        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(
                    np.hypot(*eye_to_eye) * 2.0 * rect_scale,
                    np.hypot(*eye_to_mouth) * 1.8 * rect_scale,
                )
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (
                    int(np.floor(min(quad[:, 0]))),
                    int(np.floor(min(quad[:, 1]))),
                    int(np.ceil(max(quad[:, 0]))),
                    int(np.ceil(max(quad[:, 1]))),
                )
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1),
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(
                        self.input_img,
                        ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
                        "reflect",
                    )
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(
                        1.0
                        - np.minimum(
                            np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]
                        ),
                        1.0
                        - np.minimum(
                            np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]
                        ),
                    )
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype("float32")
                    pad_img += (blur_img - pad_img) * np.clip(
                        mask * 3.0 + 1.0, 0.0, 1.0
                    )
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(
                        mask, 0.0, 1.0
                    )
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode="constant"):
        """Align and warp faces with face template."""
        if self.pad_blur:
            assert (
                len(self.pad_input_imgs) == len(self.all_landmarks_5)
            ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}"
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(
                landmark, self.face_template, method=cv2.LMEDS
            )[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == "constant":
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == "reflect101":
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == "reflect":
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img,
                affine_matrix,
                self.face_size,
                borderMode=border_mode,
                borderValue=(135, 133, 132),
            )  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f"{path}_{idx:02d}.{self.save_ext}"
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f"{path}_{idx:02d}.pth"
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, face):
        self.restored_faces.append(face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            upsample_img = cv2.resize(
                self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
            )
        else:
            upsample_img = cv2.resize(
                upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
            )

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices
        ), "length of restored_faces and affine_matrices are different."
        for restored_face, inverse_affine in zip(
            self.restored_faces, self.inverse_affine_matrices
        ):
            # Add an offset to the inverse affine matrix, for more precise back alignment
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            if self.use_parse:
                # inference
                face_input = cv2.resize(
                    restored_face, (512, 512), interpolation=cv2.INTER_LINEAR
                )
                face_input = img2tensor(
                    face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True
                )
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
                                 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    mask[out == idx] = color
                # blur the mask
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                mask[:thres, :] = 0
                mask[-thres:, :] = 0
                mask[:, :thres] = 0
                mask[:, -thres:] = 0
                mask = mask / 255.0

                mask = cv2.resize(mask, restored_face.shape[:2])
                mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_mask = mask[:, :, None]
                pasted_face = inv_restored

            else:  # use square parse maps
                mask = np.ones(self.face_size, dtype=np.float32)
                inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
                # remove the black borders
                inv_mask_erosion = cv2.erode(
                    inv_mask,
                    np.ones(
                        (int(2 * self.upscale_factor), int(2 * self.upscale_factor)),
                        np.uint8,
                    ),
                )
                pasted_face = inv_mask_erosion[:, :, None] * inv_restored
                total_face_area = np.sum(inv_mask_erosion)  # // 3
                # compute the fusion edge based on the area of face
                w_edge = int(total_face_area**0.5) // 20
                erosion_radius = w_edge * 2
                inv_mask_center = cv2.erode(
                    inv_mask_erosion,
                    np.ones((erosion_radius, erosion_radius), np.uint8),
                )
                blur_size = w_edge * 2
                inv_soft_mask = cv2.GaussianBlur(
                    inv_mask_center, (blur_size + 1, blur_size + 1), 0
                )
                if len(upsample_img.shape) == 2:  # upsample_img is gray image
                    upsample_img = upsample_img[:, :, None]
                inv_soft_mask = inv_soft_mask[:, :, None]

            if (
                len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4
            ):  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = (
                    inv_soft_mask * pasted_face
                    + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                )
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = (
                    inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
                )

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)
        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f"{path}.{self.save_ext}"
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
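
The intended call order of the helper, as a minimal sketch; the paths are placeholders and `restore_fn` stands in for any model that maps a cropped BGR uint8 face to a restored one of the same size:

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True, device='cpu')
helper.read_image('input.jpg')                    # placeholder path
helper.get_face_landmarks_5(eye_dist_threshold=5)
helper.align_warp_face()
for cropped_face in helper.cropped_faces:
    helper.add_restored_face(restore_fn(cropped_face))  # hypothetical restorer
helper.get_inverse_affine()
result = helper.paste_faces_to_input_image(save_path='output.png')
helper.clean_all()                                # reset state before the next image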
208
iopaint/plugins/facexlib/utils/face_utils.py
Normal file
@ -0,0 +1,208 @@
import cv2
import numpy as np
import torch


def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
    left, top, right, bot = bbox
    width = right - left
    height = bot - top

    if preserve_aspect:
        width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
        height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
    else:
        width_increase = height_increase = increase_area
    left = int(left - width_increase * width)
    top = int(top - height_increase * height)
    right = int(right + width_increase * width)
    bot = int(bot + height_increase * height)
    return (left, top, right, bot)


def get_valid_bboxes(bboxes, h, w):
    left = max(bboxes[0], 0)
    top = max(bboxes[1], 0)
    right = min(bboxes[2], w)
    bottom = min(bboxes[3], h)
    return (left, top, right, bottom)


def align_crop_face_landmarks(img,
                              landmarks,
                              output_size,
                              transform_size=None,
                              enable_padding=True,
                              return_inverse_affine=False,
                              shrink_ratio=(1, 1)):
    """Align and crop face with landmarks.

    The output_size and transform_size are based on width. The height is
    adjusted based on shrink_ratio_h / shrink_ratio_w.

    Modified from:
    https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

    Args:
        img (Numpy array): Input image.
        landmarks (Numpy array): 5 or 68 or 98 landmarks.
        output_size (int): Output face size.
        transform_size (int): Transform size. Usually four times
            output_size.
        enable_padding (bool): Default: True.
        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
            face for height and width (crop larger area). Default: (1, 1).

    Returns:
        (Numpy array): Cropped face.
    """
    lm_type = 'retinaface_5'  # Options: dlib_5, retinaface_5

    if isinstance(shrink_ratio, (float, int)):
        shrink_ratio = (shrink_ratio, shrink_ratio)
    if transform_size is None:
        transform_size = output_size * 4

    # Parse landmarks
    lm = np.array(landmarks)
    if lm.shape[0] == 5 and lm_type == 'retinaface_5':
        eye_left = lm[0]
        eye_right = lm[1]
        mouth_avg = (lm[3] + lm[4]) * 0.5
    elif lm.shape[0] == 5 and lm_type == 'dlib_5':
        lm_eye_left = lm[2:4]
        lm_eye_right = lm[0:2]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = lm[4]
    elif lm.shape[0] == 68:
        lm_eye_left = lm[36:42]
        lm_eye_right = lm[42:48]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[48] + lm[54]) * 0.5
    elif lm.shape[0] == 98:
        lm_eye_left = lm[60:68]
        lm_eye_right = lm[68:76]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[76] + lm[82]) * 0.5

    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    eye_to_mouth = mouth_avg - eye_avg

    # Get the oriented crop rectangle
    # x: half width of the oriented crop rectangle
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
    # norm with the hypotenuse: get the direction
    x /= np.hypot(*x)  # get the hypotenuse of a right triangle
    rect_scale = 1  # TODO: you can edit it to get a larger rect
    x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
    # y: half height of the oriented crop rectangle
    y = np.flipud(x) * [-1, 1]

    x *= shrink_ratio[1]  # width
    y *= shrink_ratio[0]  # height

    # c: center
    c = eye_avg + eye_to_mouth * 0.1
    # quad: (left_top, left_bottom, right_bottom, right_top)
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    # qsize: side length of the square
    qsize = np.hypot(*x) * 2

    quad_ori = np.copy(quad)
    # Shrink, for large faces
    # TODO: do we really need shrink
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        h, w = img.shape[0:2]
        rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
        img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
        quad /= shrink
        qsize /= shrink

    # Crop
    h, w = img.shape[0:2]
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
            int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
    if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
        img = img[crop[1]:crop[3], crop[0]:crop[2], :]
        quad -= crop[0:2]

    # Pad
    # pad: (width_left, height_top, width_right, height_bottom)
    h, w = img.shape[0:2]
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w = img.shape[0:2]
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                           np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1],
                                           np.float32(h - 1 - y) / pad[3]))
        blur = int(qsize * 0.02)
        if blur % 2 == 0:
            blur += 1
        blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))

        img = img.astype('float32')
        img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = np.clip(img, 0, 255)  # float32, [0, 255]
        quad += pad[:2]

    # Transform using cv2
    h_ratio = shrink_ratio[0] / shrink_ratio[1]
    dst_h, dst_w = int(transform_size * h_ratio), transform_size
    template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
    # use cv2.LMEDS method for the equivalence to skimage transform
    # ref: https://blog.csdn.net/yichxi/article/details/115827338
    affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
    cropped_face = cv2.warpAffine(
        img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132))  # gray

    if output_size < transform_size:
        cropped_face = cv2.resize(
            cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)

    if return_inverse_affine:
        dst_h, dst_w = int(output_size * h_ratio), output_size
        template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
        # use cv2.LMEDS method for the equivalence to skimage transform
        # ref: https://blog.csdn.net/yichxi/article/details/115827338
        affine_matrix = cv2.estimateAffinePartial2D(
            quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
        inverse_affine = cv2.invertAffineTransform(affine_matrix)
    else:
        inverse_affine = None
    return cropped_face, inverse_affine


def paste_face_back(img, face, inverse_affine):
    h, w = img.shape[0:2]
    face_h, face_w = face.shape[0:2]
    inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
    mask = np.ones((face_h, face_w, 3), dtype=np.float32)
    inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
    # remove the black borders
    inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
    inv_restored_remove_border = inv_mask_erosion * inv_restored
    total_face_area = np.sum(inv_mask_erosion) // 3
    # compute the fusion edge based on the area of face
    w_edge = int(total_face_area**0.5) // 20
    erosion_radius = w_edge * 2
    inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
    blur_size = w_edge * 2
    inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
    img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
    # float32, [0, 255]
    return img
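
A minimal paste_face_back sketch with made-up arrays; a real inverse_affine comes from align_crop_face_landmarks(..., return_inverse_affine=True):

import numpy as np

img = np.zeros((256, 256, 3), dtype=np.float32)
face = np.full((128, 128, 3), 255, dtype=np.float32)
inverse_affine = np.array([[1.0, 0.0, 64.0],
                           [0.0, 1.0, 64.0]], dtype=np.float32)  # pure translation
blended = paste_face_back(img, face, inverse_affine)
# float32 in [0, 255]; the seam is feathered by the eroded + blurred mask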
118
iopaint/plugins/facexlib/utils/misc.py
Normal file
@ -0,0 +1,118 @@
import cv2
import os
import os.path as osp
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
        one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
    """Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
    """
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    if save_dir is None:
        save_dir = os.path.join(ROOT_DIR, model_dir)
    os.makedirs(save_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(save_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
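
Usage sketch for the two helpers most callers need; the URL is the parsenet weight referenced earlier, and scandir returns a generator:

weight_path = load_file_from_url(
    url='https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth',
    model_dir='facexlib/weights')  # cached under ROOT_DIR/facexlib/weights

pth_files = list(scandir('facexlib/weights', suffix='.pth', recursive=True))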
322
iopaint/plugins/gfpgan/archs/gfpganv1_clean_arch.py
Normal file
@ -0,0 +1,322 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from .stylegan2_clean_arch import StyleGAN2GeneratorClean


class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
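
# SFT recap (illustrative comment): each decoder level applies a per-pixel
# affine transform predicted by the U-Net,
#     out = feat * scale + shift
# and with sft_half=True only the second half of the channels is modulated:
#     same, mod = torch.split(feat, feat.size(1) // 2, dim=1)
#     out = torch.cat([same, mod * scale + shift], dim=1)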


class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out


class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use half of the input channels
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
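
# Minimal usage sketch (illustrative; the settings mirror the GFPGAN defaults,
# and the input is assumed to be a normalized 512x512 face crop in [-1, 1]):
#     net = GFPGANv1Clean(out_size=512, channel_multiplier=2,
#                         different_w=True, sft_half=True)
#     x = torch.randn(1, 3, 512, 512)
#     image, out_rgbs = net(x, return_rgb=False)   # image: (1, 3, 512, 512)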

759
iopaint/plugins/gfpgan/archs/restoreformer_arch.py
Normal file
@ -0,0 +1,759 @@
"""Modified from https://github.com/wzhouxiff/RestoreFormer"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        # find closest encodings
        min_value, min_encoding_indices = torch.min(d, dim=1)

        min_encoding_indices = min_encoding_indices.unsqueeze(1)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # NOTE: the one-hot matmul above could be replaced by a plain
        # embedding lookup (TODO):
        #     min_encoding_indices = torch.argmin(d, dim=1)
        #     z_q = self.embedding(min_encoding_indices).view(z.shape)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
            (z_q - z.detach()) ** 2
        )

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for easier handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
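
# Note on the computation above (illustrative): the distance matrix `d`
# expands ||z - e||^2 = ||z||^2 + ||e||^2 - 2<z, e> batch-wise, and
# `z_q = z + (z_q - z).detach()` is the straight-through estimator: the
# forward pass uses the quantized codes while gradients bypass the
# quantization and flow back to the encoder.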


# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class MultiHeadAttnBlock(nn.Module):
    def __init__(self, in_channels, head_size=1):
        super().__init__()
        self.in_channels = in_channels
        self.head_size = head_size
        self.att_size = in_channels // head_size
        assert (
            in_channels % head_size == 0
        ), "in_channels must be divisible by head_size."

        self.norm1 = Normalize(in_channels)
        self.norm2 = Normalize(in_channels)

        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.num = 0

    def forward(self, x, y=None):
        h_ = x
        h_ = self.norm1(h_)
        if y is None:
            y = h_
        else:
            y = self.norm2(y)

        q = self.q(y)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, self.head_size, self.att_size, h * w)
        q = q.permute(0, 3, 1, 2)  # b, hw, head, att
        k = k.reshape(b, self.head_size, self.att_size, h * w)
        k = k.permute(0, 3, 1, 2)
        v = v.reshape(b, self.head_size, self.att_size, h * w)
        v = v.permute(0, 3, 1, 2)

        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2).transpose(2, 3)

        scale = int(self.att_size) ** (-0.5)
        q.mul_(scale)
        w_ = torch.matmul(q, k)
        w_ = F.softmax(w_, dim=3)

        w_ = w_.matmul(v)

        w_ = w_.transpose(1, 2).contiguous()  # [b, h*w, head, att]
        w_ = w_.view(b, h, w, -1)
        w_ = w_.permute(0, 3, 1, 2)

        w_ = self.proj_out(w_)

        return x + w_
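
# Shape walk-through (illustrative): for input (b, c, h, w) with n heads,
# q/k/v become (b, n, h*w, c/n) views; q @ k^T gives (b, n, h*w, h*w)
# attention weights (q is pre-scaled by (c/n)**-0.5), and the weighted sum
# over v is folded back to (b, c, h, w) before proj_out. Passing y enables
# cross-attention: queries come from y, keys/values from x.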


class MultiHeadEncoder(nn.Module):
    def __init__(
        self,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=512,
        z_channels=256,
        double_z=True,
        enable_mid=True,
        head_size=1,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.enable_mid = enable_mid

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        hs = {}
        # timestep embedding
        temb = None

        # downsampling
        h = self.conv_in(x)
        hs["in"] = h
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)

            if i_level != self.num_resolutions - 1:
                hs["block_" + str(i_level)] = h
                h = self.down[i_level].downsample(h)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            hs["block_" + str(i_level) + "_atten"] = h
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)
            hs["mid_atten"] = h

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        hs["out"] = h

        return hs
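
# With the default six-level 512px config (illustrative), the returned dict holds:
#     hs['in']                        first conv feature
#     hs['block_0'] .. hs['block_4']  per-level features before downsampling
#     hs['block_5_atten'], hs['mid_atten']   mid-block features
#     hs['out']                       final latent passed to quantization
# MultiHeadDecoderTransformer below reads the *_atten entries for cross-attention.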


class MultiHeadDecoder(nn.Module):
    def __init__(
        self,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=512,
        z_channels=256,
        give_pre_end=False,
        enable_mid=True,
        head_size=1,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class MultiHeadDecoderTransformer(nn.Module):
    def __init__(
        self,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=512,
        z_channels=256,
        give_pre_end=False,
        enable_mid=True,
        head_size=1,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
            )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z, hs):
        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h, hs["mid_atten"])
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](
                        h, hs["block_" + str(i_level) + "_atten"]
                    )
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class RestoreFormer(nn.Module):
    def __init__(
        self,
        n_embed=1024,
        embed_dim=256,
        ch=64,
        out_ch=3,
        ch_mult=(1, 2, 2, 4, 4, 8),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        in_channels=3,
        resolution=512,
        z_channels=256,
        double_z=False,
        enable_mid=True,
        fix_decoder=False,
        fix_codebook=True,
        fix_encoder=False,
        head_size=8,
    ):
        super(RestoreFormer, self).__init__()

        self.encoder = MultiHeadEncoder(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            double_z=double_z,
            enable_mid=enable_mid,
            head_size=head_size,
        )
        self.decoder = MultiHeadDecoderTransformer(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            enable_mid=enable_mid,
            head_size=head_size,
        )

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        if fix_decoder:
            for _, param in self.decoder.named_parameters():
                param.requires_grad = False
            for _, param in self.post_quant_conv.named_parameters():
                param.requires_grad = False
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False
        elif fix_codebook:
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False

        if fix_encoder:
            for _, param in self.encoder.named_parameters():
                param.requires_grad = False

    def encode(self, x):
        hs = self.encoder(x)
        h = self.quant_conv(hs["out"])
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info, hs

    def decode(self, quant, hs):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, hs)

        return dec

    def forward(self, input, **kwargs):
        quant, diff, info, hs = self.encode(input)
        dec = self.decode(quant, hs)

        return dec, None
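
# Minimal usage sketch (illustrative; pretrained weights are loaded elsewhere,
# e.g. by MyGFPGANer, and the input is a normalized face crop in [-1, 1]):
#     net = RestoreFormer().eval()
#     x = torch.randn(1, 3, 512, 512)
#     with torch.no_grad():
#         restored, _ = net(x)   # (1, 3, 512, 512)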

434
iopaint/plugins/gfpgan/archs/stylegan2_clean_arch.py
Normal file
@ -0,0 +1,434 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from iopaint.plugins.basicsr.arch_util import default_init_weights


class NormStyleCode(nn.Module):
    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
        eps=1e-8,
    ):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(
            self.modulation,
            scale=1,
            bias_fill=1,
            a=0,
            mode="fan_in",
            nonlinearity="linear",
        )

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
            / math.sqrt(in_channels * kernel_size**2)
        )
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(
            b * self.out_channels, c, self.kernel_size, self.kernel_size
        )

        # upsample or downsample if necessary
        if self.sample_mode == "upsample":
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        elif self.sample_mode == "downsample":
            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
        )
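
# Note on the math above (illustrative): for style vector s, the modulated
# kernel is w'[o, i, :, :] = w[o, i, :, :] * s[i]; demodulation then rescales
# each output channel by 1 / sqrt(sum_{i,k1,k2} w'[o, i, k1, k2]^2 + eps).
# Folding the batch into the group dimension lets a single grouped conv apply
# a different kernel per sample, which is why x is viewed as (1, b*c, h, w).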


class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
    ):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
        )
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
        )
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode="bilinear", align_corners=False
                )
            out = out + skip
        return out


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out


class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(
        self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
    ):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [
                    nn.Linear(num_style_feat, num_style_feat, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                ]
            )
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(
            self.style_mlp,
            scale=1,
            bias_fill=0,
            a=0.2,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )

        # channel list
        channels = {
            "4": int(512 * narrow),
            "8": int(512 * narrow),
            "16": int(512 * narrow),
            "32": int(512 * narrow),
            "64": int(256 * channel_multiplier * narrow),
            "128": int(128 * channel_multiplier * narrow),
            "256": int(64 * channel_multiplier * narrow),
            "512": int(32 * channel_multiplier * narrow),
            "1024": int(16 * channel_multiplier * narrow),
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels["4"], size=4)
        self.style_conv1 = StyleConv(
            channels["4"],
            channels["4"],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
        )
        self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels["4"]
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2 ** ((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f"{2 ** i}"]
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode="upsample",
                )
            )
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                )
            )
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(
            num_latent, self.num_style_feat, device=self.constant_input.weight.device
        )
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(
        self,
        styles,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        truncation=1,
        truncation_latent=None,
        inject_index=None,
        return_latents=False,
    ):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [
                    getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
                ]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = (
                styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            )
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.style_convs[::2],
            self.style_convs[1::2],
            noise[1::2],
            noise[2::2],
            self.to_rgbs,
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
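
# Sampling sketch (illustrative): draw a random style code and decode it.
#     g = StyleGAN2GeneratorClean(out_size=512)
#     z = torch.randn(1, 512)
#     img, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))
#     # img: (1, 3, 512, 512), roughly in [-1, 1]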

@ -20,11 +20,6 @@ class GFPGANPlugin(BasePlugin):
        model_path = download_model(url, model_md5)
        logger.info(f"GFPGAN model path: {model_path}")

        import facexlib

        if hasattr(facexlib.detection.retinaface, "device"):
            facexlib.detection.retinaface.device = device

        # Use GFPGAN for face enhancement
        self.face_enhancer = MyGFPGANer(
            model_path=model_path,

@ -64,11 +59,3 @@ class GFPGANPlugin(BasePlugin):
        # except Exception as error:
        #     print("wrong scale input.", error)
        return bgr_output

    def check_dep(self):
        try:
            import gfpgan
        except ImportError:
            return (
                "gfpgan is not installed, please install it first. pip install gfpgan"
            )

@ -1,12 +1,16 @@
import os

import cv2
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torchvision.transforms.functional import normalize
from torch.hub import get_dir

from .facexlib.utils.face_restoration_helper import FaceRestoreHelper
from .gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from .basicsr.img_util import img2tensor, tensor2img


class MyGFPGANer(GFPGANer):
class MyGFPGANer:
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.

@ -55,7 +59,7 @@ class MyGFPGANer(GFPGANer):
                sft_half=True,
            )
        elif arch == "RestoreFormer":
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            from .gfpgan.archs.restoreformer_arch import RestoreFormer

            self.gfpgan = RestoreFormer()

@ -82,3 +86,71 @@ class MyGFPGANer(GFPGANer):
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(
        self,
        img,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
        weight=0.5,
    ):
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, eye_dist_threshold=5
            )
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(
                cropped_face / 255.0, bgr2rgb=True, float32=True
            )
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(
                    output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
                )
            except RuntimeError as error:
                print(f"\tFailed inference for GFPGAN: {error}.")
                restored_face = cropped_face

            restored_face = restored_face.astype("uint8")
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # currently only RealESRGAN is supported for background upsampling
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=bg_img
            )
            return (
                self.face_helper.cropped_faces,
                self.face_helper.restored_faces,
                restored_img,
            )
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
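
A minimal usage sketch of the vendored helper above (illustrative: the import path and constructor arguments follow the GFPGANer-style signature visible in the hunks and are assumptions, not part of this commit):

    import cv2
    # assumed import path for the vendored helper
    from iopaint.plugins.gfpganer import MyGFPGANer

    restorer = MyGFPGANer(
        model_path="GFPGANv1.4.pth",  # assumed local checkpoint path
        upscale=1,
        arch="clean",
        channel_multiplier=2,
        device="cpu",
    )
    bgr = cv2.imread("face.jpg")
    cropped, restored, bgr_out = restorer.enhance(
        bgr, has_aligned=False, only_center_face=False, paste_back=True
    )
    cv2.imwrite("restored.jpg", bgr_out)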

@ -466,6 +466,3 @@ class RealESRGANUpscaler(BasePlugin):
        # the output is BGR
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled

    def check_dep(self):
        pass

@ -20,11 +20,6 @@ class RestoreFormerPlugin(BasePlugin):
        model_path = download_model(url, model_md5)
        logger.info(f"RestoreFormer model path: {model_path}")

        import facexlib

        if hasattr(facexlib.detection.retinaface, "device"):
            facexlib.detection.retinaface.device = device

        self.face_enhancer = MyGFPGANer(
            model_path=model_path,
            upscale=1,

@ -47,11 +42,3 @@ class RestoreFormerPlugin(BasePlugin):
        )
        logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
        return bgr_output

    def check_dep(self):
        try:
            import gfpgan
        except ImportError:
            return (
                "gfpgan is not installed, please install it first. pip install gfpgan"
            )

@ -30,7 +30,6 @@ _CANDIDATES = [
    "accelerate",
    "iopaint",
    "rembg",
    "gfpgan",
]
# Check once at runtime
for name in _CANDIDATES:

@ -26,6 +26,11 @@ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})

person_p = current_dir / "image.png"
person_bgr_img = cv2.imread(str(person_p))
person_rgb_img = cv2.cvtColor(person_bgr_img, cv2.COLOR_BGR2RGB)
person_rgb_img = cv2.resize(person_rgb_img, (512, 512))


def _save(img, name):
    cv2.imwrite(str(save_dir / name), img)

@ -86,7 +91,7 @@ def test_gfpgan(device):
    check_device(device)
    model = GFPGANPlugin(device)
    res = model.gen_image(
        rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
        person_rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
    )
    _save(res, f"test_gfpgan_{device}.png")

@ -96,7 +101,8 @@ def test_restoreformer(device):
    check_device(device)
    model = RestoreFormerPlugin(device)
    res = model.gen_image(
        rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
        person_rgb_img,
        RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64),
    )
    _save(res, f"test_restoreformer_{device}.png")