remove gfpgan dep
parent ffdf5e06e1
commit 60b1411d6b
@@ -8,4 +8,3 @@ def install(package):
 
 def install_plugins_package():
     install("rembg")
-    install("gfpgan")
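For reference, a minimal sketch (not part of the diff itself) of what the plugin installer helper looks like once this hunk is applied; the install() helper is the one shown in the hunk context:

    def install_plugins_package():
        # gfpgan is no longer installed on demand; the pieces it needs are vendored below
        install("rembg")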
iopaint/plugins/basicsr/img_util.py  (new file, 172 lines)
@@ -0,0 +1,172 @@
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, np.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result


def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output


def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also norm
            to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    ok = cv2.imwrite(file_path, img, params)
    if not ok:
        raise IOError('Failed in writing images.')


def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray]: Cropped images.
    """
    if crop_border == 0:
        return imgs
    else:
        if isinstance(imgs, list):
            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
        else:
            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
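A minimal usage sketch (not part of the diff) showing how the vendored img2tensor/tensor2img helpers above are typically paired; the input path "face.jpg" is a placeholder:

    import cv2
    from iopaint.plugins.basicsr.img_util import img2tensor, tensor2img

    bgr = cv2.imread("face.jpg")                             # HWC, BGR, uint8
    t = img2tensor(bgr.astype("float32") / 255.0)            # CHW, RGB, float32 tensor in [0, 1]
    bgr_again = tensor2img(t, rgb2bgr=True, min_max=(0, 1))  # back to HWC, BGR, uint8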
iopaint/plugins/facexlib/.gitignore  (new file, vendored, 135 lines)
@@ -0,0 +1,135 @@
.vscode
*.pth
*.png
*.jpg
version.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
iopaint/plugins/facexlib/__init__.py  (new file, 3 lines)
@@ -0,0 +1,3 @@
# flake8: noqa
from .detection import *
from .utils import *
iopaint/plugins/facexlib/detection/__init__.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch
from copy import deepcopy

from ..utils import load_file_from_url
from .retinaface import RetinaFace


def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
    if model_name == 'retinaface_resnet50':
        model = RetinaFace(network_name='resnet50', half=half, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
    elif model_name == 'retinaface_mobile0.25':
        model = RetinaFace(network_name='mobile0.25', half=half, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)

    # TODO: clean pretrained model
    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
    # remove unnecessary 'module.'
    for k, v in deepcopy(load_net).items():
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
    model.load_state_dict(load_net, strict=True)
    model.eval()
    model = model.to(device)
    return model
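A minimal usage sketch (not part of the diff) for the detector factory above; "face.jpg" is a placeholder path, and the RetinaFace weights are downloaded on first use by load_file_from_url:

    import cv2
    import torch
    from iopaint.plugins.facexlib.detection import init_detection_model

    det = init_detection_model("retinaface_resnet50", half=False, device="cpu")
    img = cv2.imread("face.jpg")  # BGR, uint8
    with torch.no_grad():
        faces = det.detect_faces(img, conf_threshold=0.8)
    # each row: 4 box coordinates, 1 confidence score, 10 landmark coordinates (5 points)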
iopaint/plugins/facexlib/detection/align_trans.py  (new file, 219 lines)
@@ -0,0 +1,219 @@
import cv2
import numpy as np

from .matlab_cp2tform import get_similarity_transform_for_cv2

# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]

DEFAULT_CROP_SIZE = (96, 112)


class FaceWarpException(Exception):

    def __str__(self):
        return 'In File {}:{}'.format(__file__, super.__str__(self))


def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
    """
    Function:
    ----------
        get reference 5 key points according to crop settings:
        0. Set default crop_size:
            if default_square:
                crop_size = (112, 112)
            else:
                crop_size = (96, 112)
        1. Pad the crop_size by inner_padding_factor in each side;
        2. Resize crop_size into (output_size - outer_padding*2),
            pad into output_size with outer_padding;
        3. Output reference_5point;
    Parameters:
    ----------
        @output_size: (w, h) or None
            size of aligned face image
        @inner_padding_factor: (w_factor, h_factor)
            padding factor for inner (w, h)
        @outer_padding: (w_pad, h_pad)
            each row is a pair of coordinates (x, y)
        @default_square: True or False
            if True:
                default crop_size = (112, 112)
            else:
                default crop_size = (96, 112);
        !!! make sure, if output_size is not None:
                (output_size - outer_padding)
                = some_scale * (default crop_size * (1.0 +
                inner_padding_factor))
    Returns:
    ----------
        @reference_5point: 5x2 np.array
            each row is a pair of transformed coordinates (x, y)
    """

    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):

        return tmp_5pts

    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
        if output_size is None:
            return tmp_5pts
        else:
            raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
        output_size = tmp_crop_size * \
            (1 + inner_padding_factor * 2).astype(np.int32)
        output_size += np.array(outer_padding)
    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding)'
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor
    #    size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
    #    tmp_5pts = tmp_5pts + size_diff / 2
    tmp_crop_size = size_bf_outer_pad

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)
    tmp_crop_size = output_size

    return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
    """
    Function:
    ----------
        get affine transform matrix 'tfm' from src_pts to dst_pts
    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points matrix, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points matrix, each row is a pair of coordinates (x, y)
    Returns:
    ----------
        @tfm: 2x3 np.array
            transform matrix from src_pts to dst_pts
    """

    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)

    if rank == 3:
        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
    elif rank == 2:
        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])

    return tfm


def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'):
    """
    Function:
    ----------
        apply affine transform 'trans' to uv
    Parameters:
    ----------
        @src_img: 3x3 np.array
            input image
        @facial_pts: could be
            1)a list of K coordinates (x,y)
            or
            2) Kx2 or 2xK np.array
            each row or col is a pair of coordinates (x, y)
        @reference_pts: could be
            1) a list of K coordinates (x,y)
            or
            2) Kx2 or 2xK np.array
            each row or col is a pair of coordinates (x, y)
            or
            3) None
            if None, use default reference facial points
        @crop_size: (w, h)
            output face image size
        @align_type: transform type, could be one of
            1) 'similarity': use similarity transform
            2) 'cv2_affine': use the first 3 points to do affine transform,
            by calling cv2.getAffineTransform()
            3) 'affine': use all points to do affine transform
    Returns:
    ----------
        @face_img: output face image with size (w, h) = @crop_size
    """

    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
                                                        default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException('facial_pts and reference_pts must have the same shape')

    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
    else:
        tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))

    return face_img
iopaint/plugins/facexlib/detection/matlab_cp2tform.py  (new file, 317 lines)
@@ -0,0 +1,317 @@
import numpy as np
from numpy.linalg import inv, lstsq
from numpy.linalg import matrix_rank as rank
from numpy.linalg import norm


class MatlabCp2tormException(Exception):

    def __str__(self):
        return 'In File {}:{}'.format(__file__, super.__str__(self))


def tformfwd(trans, uv):
    """
    Function:
    ----------
        apply affine transform 'trans' to uv

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)

    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of transformed coordinates (x, y)
    """
    uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
    xy = np.dot(uv, trans)
    xy = xy[:, 0:-1]
    return xy


def tforminv(trans, uv):
    """
    Function:
    ----------
        apply the inverse of affine transform 'trans' to uv

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)

    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of inverse-transformed coordinates (x, y)
    """
    Tinv = inv(trans)
    xy = tformfwd(Tinv, uv)
    return xy


def findNonreflectiveSimilarity(uv, xy, options=None):
    options = {'K': 2}

    K = options['K']
    M = xy.shape[0]
    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector

    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
    X = np.vstack((tmp1, tmp2))

    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    U = np.vstack((u, v))

    # We know that X * r = U
    if rank(X) >= 2 * K:
        r, _, _, _ = lstsq(X, U, rcond=-1)
        r = np.squeeze(r)
    else:
        raise Exception('cp2tform:twoUniquePointsReq')
    sc = r[0]
    ss = r[1]
    tx = r[2]
    ty = r[3]

    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
    T = inv(Tinv)
    T[:, 2] = np.array([0, 0, 1])

    return T, Tinv


def findSimilarity(uv, xy, options=None):
    options = {'K': 2}

    #    uv = np.array(uv)
    #    xy = np.array(xy)

    # Solve for trans1
    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)

    # Solve for trans2

    # manually reflect the xy data across the Y-axis
    xyR = xy
    xyR[:, 0] = -1 * xyR[:, 0]

    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)

    # manually reflect the tform to undo the reflection done on xyR
    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])

    trans2 = np.dot(trans2r, TreflectY)

    # Figure out if trans1 or trans2 is better
    xy1 = tformfwd(trans1, uv)
    norm1 = norm(xy1 - xy)

    xy2 = tformfwd(trans2, uv)
    norm2 = norm(xy2 - xy)

    if norm1 <= norm2:
        return trans1, trans1_inv
    else:
        trans2_inv = inv(trans2)
        return trans2, trans2_inv


def get_similarity_transform(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'trans':
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y, 1] = [u, v, 1] * trans

    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        @reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform

    Returns:
    ----------
        @trans: 3x3 np.array
            transform matrix from uv to xy
        trans_inv: 3x3 np.array
            inverse of trans, transform matrix from xy to uv
    """

    if reflective:
        trans, trans_inv = findSimilarity(src_pts, dst_pts)
    else:
        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)

    return trans, trans_inv


def cvt_tform_mat_for_cv2(trans):
    """
    Function:
    ----------
        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix from uv to xy

    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    cv2_trans = trans[:, 0:2].T

    return cv2_trans


def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform

    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
    cv2_trans = cvt_tform_mat_for_cv2(trans)

    return cv2_trans


if __name__ == '__main__':
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    # In Matlab, run:
    #
    #   uv = [u'; v'];
    #   xy = [x'; y'];
    #   tform_sim=cp2tform(uv,xy,'similarity');
    #
    #   trans = tform_sim.tdata.T
    #   ans =
    #       -0.0764   -1.6190         0
    #        1.6190   -0.0764         0
    #       -3.2156    0.0290    1.0000
    #   trans_inv = tform_sim.tdata.Tinv
    #   ans =
    #
    #       -0.0291    0.6163         0
    #       -0.6163   -0.0291         0
    #       -0.0756    1.9826    1.0000
    #   xy_m=tformfwd(tform_sim, u,v)
    #
    #   xy_m =
    #
    #       -3.2156    0.0290
    #        1.1833   -9.9143
    #        5.0323    2.8853
    #   uv_m=tforminv(tform_sim, x,y)
    #
    #   uv_m =
    #
    #        0.5698    1.3953
    #        6.0872    2.2733
    #       -2.6570    4.3314
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    uv = np.array((u, v)).T
    xy = np.array((x, y)).T

    print('\n--->uv:')
    print(uv)
    print('\n--->xy:')
    print(xy)

    trans, trans_inv = get_similarity_transform(uv, xy)

    print('\n--->trans matrix:')
    print(trans)

    print('\n--->trans_inv matrix:')
    print(trans_inv)

    print('\n---> apply transform to uv')
    print('\nxy_m = uv_augmented * trans')
    uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
    xy_m = np.dot(uv_aug, trans)
    print(xy_m)

    print('\nxy_m = tformfwd(trans, uv)')
    xy_m = tformfwd(trans, uv)
    print(xy_m)

    print('\n---> apply inverse transform to xy')
    print('\nuv_m = xy_augmented * trans_inv')
    xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
    uv_m = np.dot(xy_aug, trans_inv)
    print(uv_m)

    print('\nuv_m = tformfwd(trans_inv, xy)')
    uv_m = tformfwd(trans_inv, xy)
    print(uv_m)

    uv_m = tforminv(trans, xy)
    print('\nuv_m = tforminv(trans, xy)')
    print(uv_m)
iopaint/plugins/facexlib/detection/retinaface.py  (new file, 419 lines)
@@ -0,0 +1,419 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter

from .align_trans import get_reference_facial_points, warp_and_crop_face
from .retinaface_net import (
    FPN,
    SSH,
    MobileNetV1,
    make_bbox_head,
    make_class_head,
    make_landmark_head,
)
from .retinaface_utils import (
    PriorBox,
    batched_decode,
    batched_decode_landm,
    decode,
    decode_landm,
    py_cpu_nms,
)


def generate_config(network_name):
    cfg_mnet = {
        "name": "mobilenet0.25",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 32,
        "ngpu": 1,
        "epoch": 250,
        "decay1": 190,
        "decay2": 220,
        "image_size": 640,
        "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
        "in_channel": 32,
        "out_channel": 64,
    }

    cfg_re50 = {
        "name": "Resnet50",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 24,
        "ngpu": 4,
        "epoch": 100,
        "decay1": 70,
        "decay2": 90,
        "image_size": 840,
        "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
        "in_channel": 256,
        "out_channel": 256,
    }

    if network_name == "mobile0.25":
        return cfg_mnet
    elif network_name == "resnet50":
        return cfg_re50
    else:
        raise NotImplementedError(f"network_name={network_name}")


class RetinaFace(nn.Module):
    def __init__(self, network_name="resnet50", half=False, phase="test", device=None):
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )

        super(RetinaFace, self).__init__()
        self.half_inference = half
        cfg = generate_config(network_name)
        self.backbone = cfg["name"]

        self.model_name = f"retinaface_{network_name}"
        self.cfg = cfg
        self.phase = phase
        self.target_size, self.max_size = 1600, 2150
        self.resize, self.scale, self.scale1 = 1.0, None, None
        self.mean_tensor = torch.tensor(
            [[[[104.0]], [[117.0]], [[123.0]]]], device=self.device
        )
        self.reference = get_reference_facial_points(default_square=True)
        # Build network.
        backbone = None
        if cfg["name"] == "mobilenet0.25":
            backbone = MobileNetV1()
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
        elif cfg["name"] == "Resnet50":
            import torchvision.models as models

            backbone = models.resnet50(pretrained=False)
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])

        in_channels_stage2 = cfg["in_channel"]
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]

        out_channels = cfg["out_channel"]
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])

        self.to(self.device)
        self.eval()
        if self.half_inference:
            self.half()

    def forward(self, inputs):
        out = self.body(inputs)

        if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50":
            out = list(out.values())
        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
        ldm_regressions = torch.cat(tmp, dim=1)

        if self.phase == "train":
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (
                bbox_regressions,
                F.softmax(classifications, dim=-1),
                ldm_regressions,
            )
        return output

    def __detect_faces(self, inputs):
        # get scale
        height, width = inputs.shape[2:]
        self.scale = torch.tensor(
            [width, height, width, height], dtype=torch.float32, device=self.device
        )
        tmp = [
            width,
            height,
            width,
            height,
            width,
            height,
            width,
            height,
            width,
            height,
        ]
        self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)

        # forward
        inputs = inputs.to(self.device)
        if self.half_inference:
            inputs = inputs.half()
        loc, conf, landmarks = self(inputs)

        # get priorbox
        priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
        priors = priorbox.forward().to(self.device)

        return loc, conf, landmarks, priors

    # single image detection
    def transform(self, image, use_origin_size):
        # convert to opencv format
        if isinstance(image, Image.Image):
            image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)

        # testing scale
        im_size_min = np.min(image.shape[0:2])
        im_size_max = np.max(image.shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            image = cv2.resize(
                image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
            )

        # convert to torch.tensor format
        # image -= (104, 117, 123)
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).unsqueeze(0)

        return image, resize

    def detect_faces(
        self,
        image,
        conf_threshold=0.8,
        nms_threshold=0.4,
        use_origin_size=True,
    ):
        image, self.resize = self.transform(image, use_origin_size)
        image = image.to(self.device)
        if self.half_inference:
            image = image.half()
        image = image - self.mean_tensor

        loc, conf, landmarks, priors = self.__detect_faces(image)

        boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
        boxes = boxes * self.scale / self.resize
        boxes = boxes.cpu().numpy()

        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

        landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
        landmarks = landmarks * self.scale1 / self.resize
        landmarks = landmarks.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > conf_threshold)[0]
        boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

        # sort
        order = scores.argsort()[::-1]
        boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

        # do NMS
        bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False
        )
        keep = py_cpu_nms(bounding_boxes, nms_threshold)
        bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
        # self.t['forward_pass'].toc()
        # print(self.t['forward_pass'].average_time)
        # import sys
        # sys.stdout.flush()
        return np.concatenate((bounding_boxes, landmarks), axis=1)

    def __align_multi(self, image, boxes, landmarks, limit=None):
        if len(boxes) < 1:
            return [], []

        if limit:
            boxes = boxes[:limit]
            landmarks = landmarks[:limit]

        faces = []
        for landmark in landmarks:
            facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]

            warped_face = warp_and_crop_face(
                np.array(image), facial5points, self.reference, crop_size=(112, 112)
            )
            faces.append(warped_face)

        return np.concatenate((boxes, landmarks), axis=1), faces

    def align_multi(self, img, conf_threshold=0.8, limit=None):
        rlt = self.detect_faces(img, conf_threshold=conf_threshold)
        boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]

        return self.__align_multi(img, boxes, landmarks, limit)

    # batched detection
    def batched_transform(self, frames, use_origin_size):
        """
        Arguments:
            frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
                type=np.float32, BGR format).
            use_origin_size: whether to use origin size.
        """
        from_PIL = True if isinstance(frames[0], Image.Image) else False

        # convert to opencv format
        if from_PIL:
            frames = [
                cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames
            ]
            frames = np.asarray(frames, dtype=np.float32)

        # testing scale
        im_size_min = np.min(frames[0].shape[0:2])
        im_size_max = np.max(frames[0].shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            if not from_PIL:
                frames = F.interpolate(frames, scale_factor=resize)
            else:
                frames = [
                    cv2.resize(
                        frame,
                        None,
                        None,
                        fx=resize,
                        fy=resize,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    for frame in frames
                ]

        # convert to torch.tensor format
        if not from_PIL:
            frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
        else:
            frames = frames.transpose((0, 3, 1, 2))
            frames = torch.from_numpy(frames)

        return frames, resize

    def batched_detect_faces(
        self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True
    ):
        """
        Arguments:
            frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
                type=np.uint8, BGR format).
            conf_threshold: confidence threshold.
            nms_threshold: nms threshold.
            use_origin_size: whether to use origin size.
        Returns:
            final_bounding_boxes: list of np.array ([n_boxes, 5],
                type=np.float32).
            final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
        """
        # self.t['forward_pass'].tic()
        frames, self.resize = self.batched_transform(frames, use_origin_size)
        frames = frames.to(self.device)
        frames = frames - self.mean_tensor

        b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)

        final_bounding_boxes, final_landmarks = [], []

        # decode
        priors = priors.unsqueeze(0)
        b_loc = (
            batched_decode(b_loc, priors, self.cfg["variance"])
            * self.scale
            / self.resize
        )
        b_landmarks = (
            batched_decode_landm(b_landmarks, priors, self.cfg["variance"])
            * self.scale1
            / self.resize
        )
        b_conf = b_conf[:, :, 1]

        # index for selection
        b_indice = b_conf > conf_threshold

        # concat
        b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()

        for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
            # ignore low scores
            pred, landm = pred[inds, :], landm[inds, :]
            if pred.shape[0] == 0:
                final_bounding_boxes.append(np.array([], dtype=np.float32))
                final_landmarks.append(np.array([], dtype=np.float32))
                continue

            # sort
            # order = score.argsort(descending=True)
            # box, landm, score = box[order], landm[order], score[order]

            # to CPU
            bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()

            # NMS
            keep = py_cpu_nms(bounding_boxes, nms_threshold)
            bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]

            # append
            final_bounding_boxes.append(bounding_boxes)
            final_landmarks.append(landmarks)
        # self.t['forward_pass'].toc(average=True)
        # self.batch_time += self.t['forward_pass'].diff
        # self.total_frame += len(frames)
        # print(self.batch_time / self.total_frame)

        return final_bounding_boxes, final_landmarks
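A minimal sketch (not part of the diff) of the alignment path in the RetinaFace class above, which is what face-restoration plugins use to obtain canonical 112x112 crops; "face.jpg" is a placeholder path:

    import cv2
    from iopaint.plugins.facexlib.detection import init_detection_model

    det = init_detection_model("retinaface_resnet50", device="cpu")
    img = cv2.imread("face.jpg")  # BGR, uint8
    boxes_and_landmarks, aligned = det.align_multi(img, conf_threshold=0.8)
    # `aligned` is a list of 112x112 BGR crops warped onto the 5-point reference template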
iopaint/plugins/facexlib/detection/retinaface_net.py  (new file, 196 lines)
@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(inp, oup, stride=1, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_bn_no_relu(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )


def conv_bn1X1(inp, oup, stride, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_dw(inp, oup, stride, leaky=0.1):
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class SSH(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        leaky = 0
        if (out_channel <= 64):
            leaky = 0.1
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

        self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)

        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)

        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        out = F.relu(out)
        return out


class FPN(nn.Module):

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0
        if (out_channels <= 64):
            leaky = 0.1
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        # names = list(input.keys())
        # input = list(input.values())

        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])

        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
        output2 = output2 + up3
        output2 = self.merge2(output2)

        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
        output1 = output1 + up2
        output1 = self.merge1(output1)

        out = [output1, output2, output3]
        return out


class MobileNetV1(nn.Module):

    def __init__(self):
        super(MobileNetV1, self).__init__()
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            conv_dw(8, 16, 1),  # 7
            conv_dw(16, 32, 2),  # 11
            conv_dw(32, 32, 1),  # 19
            conv_dw(32, 64, 2),  # 27
            conv_dw(64, 64, 1),  # 43
        )
        self.stage2 = nn.Sequential(
            conv_dw(64, 128, 2),  # 43 + 16 = 59
            conv_dw(128, 128, 1),  # 59 + 32 = 91
            conv_dw(128, 128, 1),  # 91 + 32 = 123
            conv_dw(128, 128, 1),  # 123 + 32 = 155
            conv_dw(128, 128, 1),  # 155 + 32 = 187
            conv_dw(128, 128, 1),  # 187 + 32 = 219
        )
        self.stage3 = nn.Sequential(
            conv_dw(128, 256, 2),  # 219 +3 2 = 241
            conv_dw(256, 256, 1),  # 241 + 64 = 301
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        # x = self.model(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x


class ClassHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 10)


def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
    classhead = nn.ModuleList()
    for i in range(fpn_num):
        classhead.append(ClassHead(inchannels, anchor_num))
    return classhead


def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
    bboxhead = nn.ModuleList()
    for i in range(fpn_num):
        bboxhead.append(BboxHead(inchannels, anchor_num))
    return bboxhead


def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
    landmarkhead = nn.ModuleList()
    for i in range(fpn_num):
        landmarkhead.append(LandmarkHead(inchannels, anchor_num))
    return landmarkhead
421
iopaint/plugins/facexlib/detection/retinaface_utils.py
Normal file
421
iopaint/plugins/facexlib/detection/retinaface_utils.py
Normal file
@ -0,0 +1,421 @@
import numpy as np
import torch
import torchvision
from itertools import product as product
from math import ceil


class PriorBox(object):

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
        self.name = 's'

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output


def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    keep = torchvision.ops.nms(
        boxes=torch.Tensor(dets[:, :4]),
        scores=torch.Tensor(dets[:, 4]),
        iou_threshold=thresh,
    )

    return list(keep)


def point_form(boxes):
    """Convert prior_boxes to (xmin, ymin, xmax, ymax)
    representation for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat(
        (
            boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin
            boxes[:, :2] + boxes[:, 2:] / 2),
        1)  # xmax, ymax


def center_size(boxes):
    """Convert prior_boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes
    Return:
        boxes: (tensor) Converted cx, cy, w, h form of boxes.
    """
    # Wrap the two halves in a tuple so torch.cat receives (tensors, dim).
    return torch.cat(
        ((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
         boxes[:, 2:] - boxes[:, :2]),
        1)  # w, h


def intersect(box_a, box_b):
    """ We resize both tensors to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
        box_a: (tensor) bounding boxes, Shape: [A,4].
        box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
        (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
    is simply the intersection over union of two boxes. Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """
    return iou of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """
    return iof of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    return area_i / np.maximum(area_a[:, np.newaxis], 1)


def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then return the matched indices
    corresponding to both confidence and location preds.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
        variances: (tensor) Variances corresponding to each prior coord,
            Shape: [num_priors, 4].
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index
    Return:
        The matched indices corresponding to 1)location 2)confidence
        3)landm preds.
    """
    # jaccard index
    overlaps = jaccard(truths, point_form(priors))
    # (Bipartite Matching)
    # [1,num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # determine which box this anchor is responsible for predicting
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]  # Shape: [num_priors,4], the bbox matched to each anchor
    conf = labels[best_truth_idx]  # Shape: [num_priors], the label matched to each anchor
    conf[best_truth_overlap < threshold] = 0  # label as background: priors with overlap below threshold are negatives
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]


def encode_landm(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 10].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded landm (tensor), Shape: [num_priors, 10]
    """

    # dist b/t match center and prior's center
    matched = torch.reshape(matched, (matched.size(0), 5, 2))
    priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
    g_cxcy = matched[:, :, :2] - priors[:, :, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, :, 2:])
    # g_cxcy /= priors[:, :, 2:]
    g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
    # return target for smooth_l1_loss
    return g_cxcy


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
                       priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    tmp = (
        priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
    )
    landms = torch.cat(tmp, dim=1)
    return landms


def batched_decode(b_loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        b_loc (tensor): location predictions for loc layers,
            Shape: [num_batches,num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [1,num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """
    boxes = (
        priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
    )
    boxes = torch.cat(boxes, dim=2)

    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
    boxes[:, :, 2:] += boxes[:, :, :2]
    return boxes


def batched_decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_batches,num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [1,num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    landms = (
        priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
        priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
    )
    landms = torch.cat(landms, dim=2)
    return landms


def log_sum_exp(x):
    """Utility function for computing log_sum_exp.
    This will be used to determine unaveraged confidence loss across
    all examples in a batch.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max


# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class pred scores for the img, Shape: [num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The maximum number of box preds to consider.
    Return:
        The indices of the kept boxes with respect to num_priors.
    """

    keep = torch.Tensor(scores.size(0)).fill_(0).long()
    if boxes.numel() == 0:
        return keep
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    # I = I[v >= 0.01]
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    # keep = torch.Tensor()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        # keep.append(i)
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # store element-wise max with next highest score
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # check sizes of xx1 and xx2.. after each iteration
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union  # store result in iou
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
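Not part of the commit: a small sketch of how PriorBox and decode fit together; the min_sizes/steps/variances below are commonly used RetinaFace settings assumed for illustration, not taken from this diff.

# Illustrative sketch only: generate priors for a 640x640 input and decode raw offsets.
import torch

cfg = {'min_sizes': [[16, 32], [64, 128], [256, 512]], 'steps': [8, 16, 32], 'clip': False}  # assumed config
priors = PriorBox(cfg, image_size=(640, 640)).forward()   # [num_priors, 4] in (cx, cy, w, h)
variances = [0.1, 0.2]                                     # assumed variances

loc = torch.zeros(priors.shape[0], 4)                      # placeholder network output
boxes = decode(loc, priors, variances)                     # (xmin, ymin, xmax, ymax), normalized to [0, 1]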
24
iopaint/plugins/facexlib/parsing/__init__.py
Normal file
24
iopaint/plugins/facexlib/parsing/__init__.py
Normal file
@ -0,0 +1,24 @@
import torch

from ..utils import load_file_from_url
from .bisenet import BiSeNet
from .parsenet import ParseNet


def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
    if model_name == 'bisenet':
        model = BiSeNet(num_class=19)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
    elif model_name == 'parsenet':
        model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(load_net, strict=True)
    model.eval()
    model = model.to(device)
    return model
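Not part of the commit: a usage sketch for the factory above; it downloads the referenced weights on first call, and the input range/normalization is the caller's responsibility.

# Illustrative sketch only.
import torch

model = init_parsing_model(model_name='parsenet', device='cpu')
face = torch.rand(1, 3, 512, 512)            # a normalized 512x512 face crop (assumed)
with torch.no_grad():
    parsing = model(face)[0].argmax(dim=1)   # [1, 512, 512] label map over 19 classes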
140
iopaint/plugins/facexlib/parsing/bisenet.py
Normal file
140
iopaint/plugins/facexlib/parsing/bisenet.py
Normal file
@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import ResNet18


class ConvBNReLU(nn.Module):

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x


class BiSeNetOutput(nn.Module):

    def __init__(self, in_chan, mid_chan, num_class):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)

    def forward(self, x):
        feat = self.conv(x)
        out = self.conv_out(feat)
        return out, feat


class AttentionRefinementModule(nn.Module):

    def __init__(self, in_chan, out_chan):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out


class ContextPath(nn.Module):

    def __init__(self):
        super(ContextPath, self).__init__()
        self.resnet = ResNet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        h8, w8 = feat8.size()[2:]
        h16, w16 = feat16.size()[2:]
        h32, w32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (h32, w32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16


class FeatureFusionModule(nn.Module):

    def __init__(self, in_chan, out_chan):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out


class BiSeNet(nn.Module):

    def __init__(self, num_class):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, num_class)
        self.conv_out16 = BiSeNetOutput(128, 64, num_class)
        self.conv_out32 = BiSeNetOutput(128, 64, num_class)

    def forward(self, x, return_feat=False):
        h, w = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # return res3b1 feature
        feat_sp = feat_res8  # replace spatial path feature with res3b1 feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        out, feat = self.conv_out(feat_fuse)
        out16, feat16 = self.conv_out16(feat_cp8)
        out32, feat32 = self.conv_out32(feat_cp16)

        out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
        out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
        out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)

        if return_feat:
            feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
            feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
            feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
            return out, out16, out32, feat, feat16, feat32
        else:
            return out, out16, out32
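Not part of the commit: BiSeNet returns three logit maps, all upsampled back to the input resolution; a quick shape check as a sketch.

# Illustrative sketch only.
import torch

net = BiSeNet(num_class=19).eval()
with torch.no_grad():
    out, out16, out32 = net(torch.rand(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 19, 512, 512]); out16/out32 are auxiliary outputs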
194
iopaint/plugins/facexlib/parsing/parsenet.py
Normal file
194
iopaint/plugins/facexlib/parsing/parsenet.py
Normal file
@ -0,0 +1,194 @@
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
import numpy as np
import torch.nn as nn
from torch.nn import functional as F


class NormLayer(nn.Module):
    """Normalization Layers.

    Args:
        channels: input channels, for batch norm and instance norm.
        input_size: input shape without batch size, for layer norm.
    """

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = lambda x: x * 1.0
        else:
            assert 1 == 0, f'Norm type {norm_type} not support.'

    def forward(self, x, ref=None):
        if self.norm_type == 'spade':
            return self.norm(x, ref)
        else:
            return self.norm(x)


class ReluLayer(nn.Module):
    """Relu Layer.

    Args:
        relu type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: default relu slope 0.2
            - PRelu
            - SELU
            - none: direct pass
    """

    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            self.func = lambda x: x * 1.0
        else:
            assert 1 == 0, f'Relu type {relu_type} not support.'

    def forward(self, x):
        return self.func(x)


class ConvLayer(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 scale='none',
                 norm_type='none',
                 relu_type='none',
                 use_pad=True,
                 bias=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        if norm_type in ['bn']:
            bias = False

        stride = 2 if scale == 'down' else 1

        self.scale_func = lambda x: x
        if scale == 'up':
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')

        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out


class ResidualBlock(nn.Module):
    """
    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()

        if scale == 'none' and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
        scale_conf = scale_config_dict[scale]

        self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')

    def forward(self, x):
        identity = self.shortcut_func(x)

        res = self.conv1(x)
        res = self.conv2(res)
        return identity + res


class ParseNet(nn.Module):

    def __init__(self,
                 in_size=128,
                 out_size=128,
                 min_feat_size=32,
                 base_ch=64,
                 parsing_ch=19,
                 res_depth=10,
                 relu_type='LeakyReLU',
                 norm_type='bn',
                 ch_range=[32, 256]):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
        min_feat_size = min(in_size, min_feat_size)

        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        self.body = []
        for i in range(res_depth):
            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))

        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        feat = self.encoder(x)
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
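Not part of the commit: ParseNet jointly predicts a 19-class parsing map and a coarse RGB reconstruction; a quick shape check as a sketch.

# Illustrative sketch only.
import torch

net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
with torch.no_grad():
    out_mask, out_img = net(torch.rand(1, 3, 512, 512))
print(out_mask.shape, out_img.shape)  # [1, 19, 512, 512] and [1, 3, 512, 512]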
69
iopaint/plugins/facexlib/parsing/resnet.py
Normal file
69
iopaint/plugins/facexlib/parsing/resnet.py
Normal file
@ -0,0 +1,69 @@
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):

    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32
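Not part of the commit: the backbone exposes stride-8/16/32 feature maps consumed by ContextPath; a quick shape check as a sketch.

# Illustrative sketch only.
import torch

backbone = ResNet18()
feat8, feat16, feat32 = backbone(torch.rand(1, 3, 512, 512))
print(feat8.shape, feat16.shape, feat32.shape)  # [1, 128, 64, 64], [1, 256, 32, 32], [1, 512, 16, 16]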
7
iopaint/plugins/facexlib/utils/__init__.py
Normal file
7
iopaint/plugins/facexlib/utils/__init__.py
Normal file
@ -0,0 +1,7 @@
from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
from .misc import img2tensor, load_file_from_url, scandir

__all__ = [
    'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back',
    'img2tensor', 'scandir'
]
473
iopaint/plugins/facexlib/utils/face_restoration_helper.py
Normal file
473
iopaint/plugins/facexlib/utils/face_restoration_helper.py
Normal file
@ -0,0 +1,473 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from ..detection import init_detection_model
from ..parsing import init_parsing_model
from ..utils.misc import img2tensor, imwrite


def get_largest_face(det_faces, h, w):
    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array(
            [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]
        )
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(
        self,
        upscale_factor,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        save_ext="png",
        template_3points=False,
        pad_blur=False,
        use_parse=False,
        device=None,
        model_rootpath=None,
    ):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = upscale_factor
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (
            self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1
        ), "crop ratio only supports >=1"
        self.face_size = (
            int(face_size * self.crop_ratio[1]),
            int(face_size * self.crop_ratio[0]),
        )

        if self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            self.face_template = np.array(
                [
                    [192.98138, 239.94708],
                    [318.90277, 240.1936],
                    [256.63416, 314.01935],
                    [201.26117, 371.41043],
                    [313.08905, 371.15118],
                ]
            )
        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # init face detection model
        self.face_det = init_detection_model(
            det_model, half=False, device=self.device, model_rootpath=model_rootpath
        )

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(
            model_name="parsenet", device=self.device, model_rootpath=model_rootpath
        )

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be image path or cv2 loaded image."""
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img

    def get_face_landmarks_5(
        self,
        only_keep_largest=False,
        only_center_face=False,
        resize=None,
        blur_ratio=0.01,
        eye_dist_threshold=None,
    ):
        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = min(h, w) / resize
            h, w = int(h / scale), int(w / scale)
            input_img = cv2.resize(
                self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4
            )

        with torch.no_grad():
            bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])
        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(
                    np.hypot(*eye_to_eye) * 2.0 * rect_scale,
                    np.hypot(*eye_to_mouth) * 1.8 * rect_scale,
                )
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (
                    int(np.floor(min(quad[:, 0]))),
                    int(np.floor(min(quad[:, 1]))),
                    int(np.ceil(max(quad[:, 0]))),
                    int(np.ceil(max(quad[:, 1]))),
                )
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1),
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(
                        self.input_img,
                        ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
                        "reflect",
                    )
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(
                        1.0
                        - np.minimum(
                            np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]
                        ),
                        1.0
                        - np.minimum(
                            np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]
                        ),
                    )
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype("float32")
                    pad_img += (blur_img - pad_img) * np.clip(
                        mask * 3.0 + 1.0, 0.0, 1.0
                    )
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(
                        mask, 0.0, 1.0
                    )
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode="constant"):
        """Align and warp faces with face template."""
        if self.pad_blur:
            assert (
                len(self.pad_input_imgs) == len(self.all_landmarks_5)
            ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}"
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(
                landmark, self.face_template, method=cv2.LMEDS
            )[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == "constant":
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == "reflect101":
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == "reflect":
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img,
                affine_matrix,
                self.face_size,
                borderMode=border_mode,
                borderValue=(135, 133, 132),
            )  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f"{path}_{idx:02d}.{self.save_ext}"
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f"{path}_{idx:02d}.pth"
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, face):
        self.restored_faces.append(face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            upsample_img = cv2.resize(
                self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
            )
        else:
            upsample_img = cv2.resize(
                upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
            )

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices
        ), "length of restored_faces and affine_matrices are different."
        for restored_face, inverse_affine in zip(
            self.restored_faces, self.inverse_affine_matrices
        ):
            # Add an offset to inverse affine matrix, for more precise back alignment
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            if self.use_parse:
                # inference
                face_input = cv2.resize(
                    restored_face, (512, 512), interpolation=cv2.INTER_LINEAR
                )
                face_input = img2tensor(
                    face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True
                )
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                    out = out.argmax(dim=1).squeeze().cpu().numpy()

                mask = np.zeros(out.shape)
                # 19-class parsing colormap: keep face regions (255), drop background/ears/neck/hair (0)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    mask[out == idx] = color
                # blur the mask
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                mask[:thres, :] = 0
                mask[-thres:, :] = 0
                mask[:, :thres] = 0
                mask[:, -thres:] = 0
                mask = mask / 255.0

                mask = cv2.resize(mask, restored_face.shape[:2])
                mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_mask = mask[:, :, None]
                pasted_face = inv_restored

            else:  # use square parse maps
                mask = np.ones(self.face_size, dtype=np.float32)
                inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
                # remove the black borders
                inv_mask_erosion = cv2.erode(
                    inv_mask,
                    np.ones(
                        (int(2 * self.upscale_factor), int(2 * self.upscale_factor)),
                        np.uint8,
                    ),
                )
                pasted_face = inv_mask_erosion[:, :, None] * inv_restored
                total_face_area = np.sum(inv_mask_erosion)  # // 3
                # compute the fusion edge based on the area of face
                w_edge = int(total_face_area**0.5) // 20
                erosion_radius = w_edge * 2
                inv_mask_center = cv2.erode(
                    inv_mask_erosion,
                    np.ones((erosion_radius, erosion_radius), np.uint8),
                )
                blur_size = w_edge * 2
                inv_soft_mask = cv2.GaussianBlur(
                    inv_mask_center, (blur_size + 1, blur_size + 1), 0
                )
                if len(upsample_img.shape) == 2:  # upsample_img is gray image
                    upsample_img = upsample_img[:, :, None]
                    inv_soft_mask = inv_soft_mask[:, :, None]

            if (
                len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4
            ):  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = (
                    inv_soft_mask * pasted_face
                    + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                )
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = (
                    inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
                )

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)
        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f"{path}.{self.save_ext}"
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
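Not part of the commit: the intended call order for the helper, as a sketch; `restore`, "input.jpg" and "output.png" are hypothetical placeholders supplied by the caller.

# Illustrative sketch only; `restore` enhances one 512x512 BGR face crop.
helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True, device="cpu")
helper.read_image("input.jpg")
helper.get_face_landmarks_5(only_center_face=False)
helper.align_warp_face()
for cropped_face in helper.cropped_faces:
    helper.add_restored_face(restore(cropped_face))  # placeholder restoration step
helper.get_inverse_affine()
result = helper.paste_faces_to_input_image(save_path="output.png")
helper.clean_all()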
208
iopaint/plugins/facexlib/utils/face_utils.py
Normal file
208
iopaint/plugins/facexlib/utils/face_utils.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
|
||||||
|
left, top, right, bot = bbox
|
||||||
|
width = right - left
|
||||||
|
height = bot - top
|
||||||
|
|
||||||
|
if preserve_aspect:
|
||||||
|
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
|
||||||
|
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
|
||||||
|
else:
|
||||||
|
width_increase = height_increase = increase_area
|
||||||
|
left = int(left - width_increase * width)
|
||||||
|
top = int(top - height_increase * height)
|
||||||
|
right = int(right + width_increase * width)
|
||||||
|
bot = int(bot + height_increase * height)
|
||||||
|
return (left, top, right, bot)
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_bboxes(bboxes, h, w):
|
||||||
|
left = max(bboxes[0], 0)
|
||||||
|
top = max(bboxes[1], 0)
|
||||||
|
right = min(bboxes[2], w)
|
||||||
|
bottom = min(bboxes[3], h)
|
||||||
|
return (left, top, right, bottom)


def align_crop_face_landmarks(img,
                              landmarks,
                              output_size,
                              transform_size=None,
                              enable_padding=True,
                              return_inverse_affine=False,
                              shrink_ratio=(1, 1)):
    """Align and crop face with landmarks.

    The output_size and transform_size are based on width. The height is
    adjusted based on shrink_ratio_h / shrink_ratio_w.

    Modified from:
    https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

    Args:
        img (Numpy array): Input image.
        landmarks (Numpy array): 5 or 68 or 98 landmarks.
        output_size (int): Output face size.
        transform_size (int): Transform size. Usually four times of
            output_size.
        enable_padding (bool): Default: True.
        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
            face for height and width (crop larger area). Default: (1, 1).

    Returns:
        (Numpy array): Cropped face.
    """
    lm_type = 'retinaface_5'  # Options: dlib_5, retinaface_5

    if isinstance(shrink_ratio, (float, int)):
        shrink_ratio = (shrink_ratio, shrink_ratio)
    if transform_size is None:
        transform_size = output_size * 4

    # Parse landmarks
    lm = np.array(landmarks)
    if lm.shape[0] == 5 and lm_type == 'retinaface_5':
        eye_left = lm[0]
        eye_right = lm[1]
        mouth_avg = (lm[3] + lm[4]) * 0.5
    elif lm.shape[0] == 5 and lm_type == 'dlib_5':
        lm_eye_left = lm[2:4]
        lm_eye_right = lm[0:2]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = lm[4]
    elif lm.shape[0] == 68:
        lm_eye_left = lm[36:42]
        lm_eye_right = lm[42:48]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[48] + lm[54]) * 0.5
    elif lm.shape[0] == 98:
        lm_eye_left = lm[60:68]
        lm_eye_right = lm[68:76]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[76] + lm[82]) * 0.5

    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    eye_to_mouth = mouth_avg - eye_avg

    # Get the oriented crop rectangle
    # x: half width of the oriented crop rectangle
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
    # norm with the hypotenuse: get the direction
    x /= np.hypot(*x)  # get the hypotenuse of a right triangle
    rect_scale = 1  # TODO: you can edit it to get larger rect
    x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
    # y: half height of the oriented crop rectangle
    y = np.flipud(x) * [-1, 1]

    x *= shrink_ratio[1]  # width
    y *= shrink_ratio[0]  # height

    # c: center
    c = eye_avg + eye_to_mouth * 0.1
    # quad: (left_top, left_bottom, right_bottom, right_top)
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    # qsize: side length of the square
    qsize = np.hypot(*x) * 2

    quad_ori = np.copy(quad)
    # Shrink, for large face
    # TODO: do we really need shrink
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        h, w = img.shape[0:2]
        rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
        img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
        quad /= shrink
        qsize /= shrink

    # Crop
    h, w = img.shape[0:2]
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
            int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
    if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
        img = img[crop[1]:crop[3], crop[0]:crop[2], :]
        quad -= crop[0:2]

    # Pad
    # pad: (width_left, height_top, width_right, height_bottom)
    h, w = img.shape[0:2]
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w = img.shape[0:2]
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                           np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1],
                                           np.float32(h - 1 - y) / pad[3]))
        blur = int(qsize * 0.02)
        if blur % 2 == 0:
            blur += 1
        blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))

        img = img.astype('float32')
        img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = np.clip(img, 0, 255)  # float32, [0, 255]
        quad += pad[:2]

    # Transform use cv2
    h_ratio = shrink_ratio[0] / shrink_ratio[1]
    dst_h, dst_w = int(transform_size * h_ratio), transform_size
    template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
    # use cv2.LMEDS method for the equivalence to skimage transform
    # ref: https://blog.csdn.net/yichxi/article/details/115827338
    affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
    cropped_face = cv2.warpAffine(
        img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132))  # gray

    if output_size < transform_size:
        cropped_face = cv2.resize(
            cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)

    if return_inverse_affine:
        dst_h, dst_w = int(output_size * h_ratio), output_size
        template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
        # use cv2.LMEDS method for the equivalence to skimage transform
        # ref: https://blog.csdn.net/yichxi/article/details/115827338
        affine_matrix = cv2.estimateAffinePartial2D(
            quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
        inverse_affine = cv2.invertAffineTransform(affine_matrix)
    else:
        inverse_affine = None
    return cropped_face, inverse_affine
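
A rough usage sketch for this function, assuming 5-point landmarks in retinaface order (left eye, right eye, nose, left mouth corner, right mouth corner). The image and landmark values below are synthetic placeholders.

```python
import numpy as np

# Synthetic stand-ins for a real frame and detector output.
img = np.full((480, 640, 3), 127, dtype=np.uint8)
landmarks5 = np.array([[210, 180], [290, 178], [250, 230],
                       [222, 270], [280, 268]], dtype=np.float32)

cropped_face, inverse_affine = align_crop_face_landmarks(
    img, landmarks5, output_size=512, return_inverse_affine=True)
# cropped_face is a 512x512 aligned crop; inverse_affine maps it back to img coordinates.
```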


def paste_face_back(img, face, inverse_affine):
    h, w = img.shape[0:2]
    face_h, face_w = face.shape[0:2]
    inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
    mask = np.ones((face_h, face_w, 3), dtype=np.float32)
    inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
    # remove the black borders
    inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
    inv_restored_remove_border = inv_mask_erosion * inv_restored
    total_face_area = np.sum(inv_mask_erosion) // 3
    # compute the fusion edge based on the area of face
    w_edge = int(total_face_area**0.5) // 20
    erosion_radius = w_edge * 2
    inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
    blur_size = w_edge * 2
    inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
    img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
    # float32, [0, 255]
    return img
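
Taken together, the helpers above support the usual crop, restore, paste loop. A minimal sketch, with the restoration step left as a placeholder function that is not part of this file:

```python
# `restore_512` is a hypothetical face restoration model that maps a
# 512x512 crop to a 512x512 output; any model with that shape contract fits here.
cropped_face, inverse_affine = align_crop_face_landmarks(
    img, landmarks5, output_size=512, return_inverse_affine=True)
restored_face = restore_512(cropped_face)  # hypothetical model call
result = paste_face_back(img.astype(np.float32),
                         restored_face.astype(np.float32), inverse_affine)
result = np.clip(result, 0, 255).astype(np.uint8)
```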
118
iopaint/plugins/facexlib/utils/misc.py
Normal file
118
iopaint/plugins/facexlib/utils/misc.py
Normal file
@ -0,0 +1,118 @@
import cv2
import os
import os.path as osp
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)
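
A small sketch of the input/output convention for these two helpers; the array and the output path below are made up:

```python
import numpy as np

bgr = (np.random.rand(64, 64, 3) * 255).astype(np.float32)   # OpenCV-style BGR image
t = img2tensor(bgr / 255.0, bgr2rgb=True, float32=True)       # torch.FloatTensor, (3, 64, 64), RGB order
imwrite(bgr.astype(np.uint8), '/tmp/example_face/out.png')    # parent folder created automatically
```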


def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
    """Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
    """
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    if save_dir is None:
        save_dir = os.path.join(ROOT_DIR, model_dir)
    os.makedirs(save_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(save_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
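
Usage is a one-liner; the URL below is only a placeholder to show the call shape, not a checkpoint published by this project:

```python
# Hypothetical checkpoint URL; replace with a real one.
weights_path = load_file_from_url(
    'https://example.com/models/detection_model.pth',
    model_dir='weights', progress=True)
state_dict = torch.load(weights_path, map_location='cpu')
```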


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
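
A short sketch of the generator in use; the directory name is a placeholder:

```python
# List every .png under a folder (relative paths), walking subdirectories.
for rel_path in scandir('inputs/masks', suffix='.png', recursive=True):
    print(rel_path)
```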
322
iopaint/plugins/gfpgan/archs/gfpganv1_clean_arch.py
Normal file
322
iopaint/plugins/gfpgan/archs/gfpganv1_clean_arch.py
Normal file
@ -0,0 +1,322 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from .stylegan2_clean_arch import StyleGAN2GeneratorClean


class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        self.sft_half = sft_half
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
styles,
|
||||||
|
conditions,
|
||||||
|
input_is_latent=False,
|
||||||
|
noise=None,
|
||||||
|
randomize_noise=True,
|
||||||
|
truncation=1,
|
||||||
|
truncation_latent=None,
|
||||||
|
inject_index=None,
|
||||||
|
return_latents=False):
|
||||||
|
"""Forward function for StyleGAN2GeneratorCSFT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
styles (list[Tensor]): Sample codes of styles.
|
||||||
|
conditions (list[Tensor]): SFT conditions to generators.
|
||||||
|
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||||
|
noise (Tensor | None): Input noise or None. Default: None.
|
||||||
|
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||||
|
truncation (float): The truncation ratio. Default: 1.
|
||||||
|
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||||
|
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||||
|
return_latents (bool): Whether to return style latents. Default: False.
|
||||||
|
"""
|
||||||
|
# style codes -> latents with Style MLP layer
|
||||||
|
if not input_is_latent:
|
||||||
|
styles = [self.style_mlp(s) for s in styles]
|
||||||
|
# noises
|
||||||
|
if noise is None:
|
||||||
|
if randomize_noise:
|
||||||
|
noise = [None] * self.num_layers # for each style conv layer
|
||||||
|
else: # use the stored noise
|
||||||
|
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
||||||
|
# style truncation
|
||||||
|
if truncation < 1:
|
||||||
|
style_truncation = []
|
||||||
|
for style in styles:
|
||||||
|
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||||
|
styles = style_truncation
|
||||||
|
# get style latents with injection
|
||||||
|
if len(styles) == 1:
|
||||||
|
inject_index = self.num_latent
|
||||||
|
|
||||||
|
if styles[0].ndim < 3:
|
||||||
|
# repeat latent code for all the layers
|
||||||
|
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||||
|
else: # used for encoder with different latent code for each layer
|
||||||
|
latent = styles[0]
|
||||||
|
elif len(styles) == 2: # mixing noises
|
||||||
|
if inject_index is None:
|
||||||
|
inject_index = random.randint(1, self.num_latent - 1)
|
||||||
|
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||||
|
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
||||||
|
latent = torch.cat([latent1, latent2], 1)
|
||||||
|
|
||||||
|
# main generation
|
||||||
|
out = self.constant_input(latent.shape[0])
|
||||||
|
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
||||||
|
skip = self.to_rgb1(out, latent[:, 1])
|
||||||
|
|
||||||
|
i = 1
|
||||||
|
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
||||||
|
noise[2::2], self.to_rgbs):
|
||||||
|
out = conv1(out, latent[:, i], noise=noise1)
|
||||||
|
|
||||||
|
# the conditions may have fewer levels
|
||||||
|
if i < len(conditions):
|
||||||
|
# SFT part to combine the conditions
|
||||||
|
if self.sft_half: # only apply SFT to half of the channels
|
||||||
|
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||||
|
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||||
|
out = torch.cat([out_same, out_sft], dim=1)
|
||||||
|
else: # apply SFT to all the channels
|
||||||
|
out = out * conditions[i - 1] + conditions[i]
|
||||||
|
|
||||||
|
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||||
|
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||||
|
i += 2
|
||||||
|
|
||||||
|
image = skip
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return image, latent
|
||||||
|
else:
|
||||||
|
return image, None
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
"""Residual block with bilinear upsampling/downsampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Channel number of the input.
|
||||||
|
out_channels (int): Channel number of the output.
|
||||||
|
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, mode='down'):
|
||||||
|
super(ResBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||||
|
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||||
|
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
||||||
|
if mode == 'down':
|
||||||
|
self.scale_factor = 0.5
|
||||||
|
elif mode == 'up':
|
||||||
|
self.scale_factor = 2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
|
||||||
|
# upsample/downsample
|
||||||
|
out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
||||||
|
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
|
||||||
|
# skip
|
||||||
|
x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
||||||
|
skip = self.skip(x)
|
||||||
|
out = out + skip
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GFPGANv1Clean(nn.Module):
|
||||||
|
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||||
|
|
||||||
|
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
||||||
|
|
||||||
|
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_size (int): The spatial size of outputs.
|
||||||
|
num_style_feat (int): Channel number of style features. Default: 512.
|
||||||
|
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||||
|
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||||
|
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||||
|
|
||||||
|
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||||
|
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||||
|
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||||
|
narrow (float): The narrow ratio for channels. Default: 1.
|
||||||
|
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
out_size,
|
||||||
|
num_style_feat=512,
|
||||||
|
channel_multiplier=1,
|
||||||
|
decoder_load_path=None,
|
||||||
|
fix_decoder=True,
|
||||||
|
# for stylegan decoder
|
||||||
|
num_mlp=8,
|
||||||
|
input_is_latent=False,
|
||||||
|
different_w=False,
|
||||||
|
narrow=1,
|
||||||
|
sft_half=False):
|
||||||
|
|
||||||
|
super(GFPGANv1Clean, self).__init__()
|
||||||
|
self.input_is_latent = input_is_latent
|
||||||
|
self.different_w = different_w
|
||||||
|
self.num_style_feat = num_style_feat
|
||||||
|
|
||||||
|
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||||
|
channels = {
|
||||||
|
'4': int(512 * unet_narrow),
|
||||||
|
'8': int(512 * unet_narrow),
|
||||||
|
'16': int(512 * unet_narrow),
|
||||||
|
'32': int(512 * unet_narrow),
|
||||||
|
'64': int(256 * channel_multiplier * unet_narrow),
|
||||||
|
'128': int(128 * channel_multiplier * unet_narrow),
|
||||||
|
'256': int(64 * channel_multiplier * unet_narrow),
|
||||||
|
'512': int(32 * channel_multiplier * unet_narrow),
|
||||||
|
'1024': int(16 * channel_multiplier * unet_narrow)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.log_size = int(math.log(out_size, 2))
|
||||||
|
first_out_size = 2**(int(math.log(out_size, 2)))
|
||||||
|
|
||||||
|
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
in_channels = channels[f'{first_out_size}']
|
||||||
|
self.conv_body_down = nn.ModuleList()
|
||||||
|
for i in range(self.log_size, 2, -1):
|
||||||
|
out_channels = channels[f'{2**(i - 1)}']
|
||||||
|
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
|
||||||
|
in_channels = out_channels
|
||||||
|
|
||||||
|
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
|
||||||
|
|
||||||
|
# upsample
|
||||||
|
in_channels = channels['4']
|
||||||
|
self.conv_body_up = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
out_channels = channels[f'{2**i}']
|
||||||
|
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
|
||||||
|
in_channels = out_channels
|
||||||
|
|
||||||
|
# to RGB
|
||||||
|
self.toRGB = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
|
||||||
|
|
||||||
|
if different_w:
|
||||||
|
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
||||||
|
else:
|
||||||
|
linear_out_channel = num_style_feat
|
||||||
|
|
||||||
|
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
||||||
|
|
||||||
|
# the decoder: stylegan2 generator with SFT modulations
|
||||||
|
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
||||||
|
out_size=out_size,
|
||||||
|
num_style_feat=num_style_feat,
|
||||||
|
num_mlp=num_mlp,
|
||||||
|
channel_multiplier=channel_multiplier,
|
||||||
|
narrow=narrow,
|
||||||
|
sft_half=sft_half)
|
||||||
|
|
||||||
|
# load pre-trained stylegan2 model if necessary
|
||||||
|
if decoder_load_path:
|
||||||
|
self.stylegan_decoder.load_state_dict(
|
||||||
|
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||||
|
# fix decoder without updating params
|
||||||
|
if fix_decoder:
|
||||||
|
for _, param in self.stylegan_decoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# for SFT modulations (scale and shift)
|
||||||
|
self.condition_scale = nn.ModuleList()
|
||||||
|
self.condition_shift = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
out_channels = channels[f'{2**i}']
|
||||||
|
if sft_half:
|
||||||
|
sft_out_channels = out_channels
|
||||||
|
else:
|
||||||
|
sft_out_channels = out_channels * 2
|
||||||
|
self.condition_scale.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||||
|
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||||
|
self.condition_shift.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||||
|
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||||
|
|
||||||
|
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
||||||
|
"""Forward function for GFPGANv1Clean.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input images.
|
||||||
|
return_latents (bool): Whether to return style latents. Default: False.
|
||||||
|
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
||||||
|
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||||
|
"""
|
||||||
|
conditions = []
|
||||||
|
unet_skips = []
|
||||||
|
out_rgbs = []
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
|
||||||
|
for i in range(self.log_size - 2):
|
||||||
|
feat = self.conv_body_down[i](feat)
|
||||||
|
unet_skips.insert(0, feat)
|
||||||
|
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
|
||||||
|
|
||||||
|
# style code
|
||||||
|
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
||||||
|
if self.different_w:
|
||||||
|
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
||||||
|
|
||||||
|
# decode
|
||||||
|
for i in range(self.log_size - 2):
|
||||||
|
# add unet skip
|
||||||
|
feat = feat + unet_skips[i]
|
||||||
|
# ResUpLayer
|
||||||
|
feat = self.conv_body_up[i](feat)
|
||||||
|
# generate scale and shift for SFT layers
|
||||||
|
scale = self.condition_scale[i](feat)
|
||||||
|
conditions.append(scale.clone())
|
||||||
|
shift = self.condition_shift[i](feat)
|
||||||
|
conditions.append(shift.clone())
|
||||||
|
# generate rgb images
|
||||||
|
if return_rgb:
|
||||||
|
out_rgbs.append(self.toRGB[i](feat))
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
image, _ = self.stylegan_decoder([style_code],
|
||||||
|
conditions,
|
||||||
|
return_latents=return_latents,
|
||||||
|
input_is_latent=self.input_is_latent,
|
||||||
|
randomize_noise=randomize_noise)
|
||||||
|
|
||||||
|
return image, out_rgbs
|
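
A minimal sketch of driving the clean GFPGAN v1 generator defined above. The flag combination (input_is_latent / different_w / sft_half) is an assumed typical configuration, not something stated in this commit, and the weights here are just the random initialization:

```python
import torch

# Assumed configuration for 512x512 faces; no pre-trained weights are loaded.
model = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=2,
    decoder_load_path=None,
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,   # assumption
    different_w=True,       # assumption
    narrow=1,
    sft_half=True,          # assumption
).eval()

with torch.no_grad():
    face = torch.rand(1, 3, 512, 512) * 2 - 1        # inputs normalized to roughly [-1, 1]
    restored, out_rgbs = model(face, return_rgb=False)
print(restored.shape)  # expected: torch.Size([1, 3, 512, 512])
```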
759
iopaint/plugins/gfpgan/archs/restoreformer_arch.py
Normal file
759
iopaint/plugins/gfpgan/archs/restoreformer_arch.py
Normal file
@ -0,0 +1,759 @@
|
|||||||
|
"""Modified from https://github.com/wzhouxiff/RestoreFormer"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class VectorQuantizer(nn.Module):
|
||||||
|
"""
|
||||||
|
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
||||||
|
____________________________________________
|
||||||
|
Discretization bottleneck part of the VQ-VAE.
|
||||||
|
Inputs:
|
||||||
|
- n_e : number of embeddings
|
||||||
|
- e_dim : dimension of embedding
|
||||||
|
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||||
|
_____________________________________________
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_e, e_dim, beta):
|
||||||
|
super(VectorQuantizer, self).__init__()
|
||||||
|
self.n_e = n_e
|
||||||
|
self.e_dim = e_dim
|
||||||
|
self.beta = beta
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||||
|
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
"""
|
||||||
|
Inputs the output of the encoder network z and maps it to a discrete
|
||||||
|
one-hot vector that is the index of the closest embedding vector e_j
|
||||||
|
z (continuous) -> z_q (discrete)
|
||||||
|
z.shape = (batch, channel, height, width)
|
||||||
|
quantization pipeline:
|
||||||
|
1. get encoder input (B,C,H,W)
|
||||||
|
2. flatten input to (B*H*W,C)
|
||||||
|
"""
|
||||||
|
# reshape z -> (batch, height, width, channel) and flatten
|
||||||
|
z = z.permute(0, 2, 3, 1).contiguous()
|
||||||
|
z_flattened = z.view(-1, self.e_dim)
|
||||||
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||||
|
|
||||||
|
d = (
|
||||||
|
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||||
|
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||||
|
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
||||||
|
)
|
||||||
|
|
||||||
|
# could possible replace this here
|
||||||
|
# #\start...
|
||||||
|
# find closest encodings
|
||||||
|
|
||||||
|
min_value, min_encoding_indices = torch.min(d, dim=1)
|
||||||
|
|
||||||
|
min_encoding_indices = min_encoding_indices.unsqueeze(1)
|
||||||
|
|
||||||
|
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
|
||||||
|
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||||
|
|
||||||
|
# dtype min encodings: torch.float32
|
||||||
|
# min_encodings shape: torch.Size([2048, 512])
|
||||||
|
# min_encoding_indices.shape: torch.Size([2048, 1])
|
||||||
|
|
||||||
|
# get quantized latent vectors
|
||||||
|
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||||
|
# .........\end
|
||||||
|
|
||||||
|
# with:
|
||||||
|
# .........\start
|
||||||
|
# min_encoding_indices = torch.argmin(d, dim=1)
|
||||||
|
# z_q = self.embedding(min_encoding_indices)
|
||||||
|
# ......\end......... (TODO)
|
||||||
|
|
||||||
|
# compute loss for embedding
|
||||||
|
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||||
|
(z_q - z.detach()) ** 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# preserve gradients
|
||||||
|
z_q = z + (z_q - z).detach()
|
||||||
|
|
||||||
|
# perplexity
|
||||||
|
|
||||||
|
e_mean = torch.mean(min_encodings, dim=0)
|
||||||
|
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||||
|
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
|
||||||
|
|
||||||
|
def get_codebook_entry(self, indices, shape):
|
||||||
|
# shape specifying (batch, height, width, channel)
|
||||||
|
# TODO: check for more easy handling with nn.Embedding
|
||||||
|
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
||||||
|
min_encodings.scatter_(1, indices[:, None], 1)
|
||||||
|
|
||||||
|
# get quantized latent vectors
|
||||||
|
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||||
|
|
||||||
|
if shape is not None:
|
||||||
|
z_q = z_q.view(shape)
|
||||||
|
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return z_q
|
||||||
|
|
||||||
|
|
||||||
|
# pytorch_diffusion + derived encoder decoder
|
||||||
|
def nonlinearity(x):
|
||||||
|
# swish
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels):
|
||||||
|
return torch.nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.with_conv:
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout,
|
||||||
|
temb_channels=512,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels)
|
||||||
|
self.conv1 = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if temb_channels > 0:
|
||||||
|
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = Normalize(out_channels)
|
||||||
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
|
self.conv2 = torch.nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.dropout(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, head_size=1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.head_size = head_size
|
||||||
|
self.att_size = in_channels // head_size
|
||||||
|
assert (
|
||||||
|
in_channels % head_size == 0
|
||||||
|
), "The size of head should be divided by the number of channels."
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels)
|
||||||
|
self.norm2 = Normalize(in_channels)
|
||||||
|
|
||||||
|
self.q = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.k = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.v = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.num = 0
|
||||||
|
|
||||||
|
def forward(self, x, y=None):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm1(h_)
|
||||||
|
if y is None:
|
||||||
|
y = h_
|
||||||
|
else:
|
||||||
|
y = self.norm2(y)
|
||||||
|
|
||||||
|
q = self.q(y)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q = q.reshape(b, self.head_size, self.att_size, h * w)
|
||||||
|
q = q.permute(0, 3, 1, 2) # b, hw, head, att
|
||||||
|
|
||||||
|
k = k.reshape(b, self.head_size, self.att_size, h * w)
|
||||||
|
k = k.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
v = v.reshape(b, self.head_size, self.att_size, h * w)
|
||||||
|
v = v.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2).transpose(2, 3)
|
||||||
|
|
||||||
|
scale = int(self.att_size) ** (-0.5)
|
||||||
|
q.mul_(scale)
|
||||||
|
w_ = torch.matmul(q, k)
|
||||||
|
w_ = F.softmax(w_, dim=3)
|
||||||
|
|
||||||
|
w_ = w_.matmul(v)
|
||||||
|
|
||||||
|
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
|
||||||
|
w_ = w_.view(b, h, w, -1)
|
||||||
|
w_ = w_.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
w_ = self.proj_out(w_)
|
||||||
|
|
||||||
|
return x + w_
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_resolutions=(16,),
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels=3,
|
||||||
|
resolution=512,
|
||||||
|
z_channels=256,
|
||||||
|
double_z=True,
|
||||||
|
enable_mid=True,
|
||||||
|
head_size=1,
|
||||||
|
**ignore_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.enable_mid = enable_mid
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = torch.nn.Conv2d(
|
||||||
|
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
if self.enable_mid:
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in,
|
||||||
|
2 * z_channels if double_z else z_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hs = {}
|
||||||
|
# timestep embedding
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
h = self.conv_in(x)
|
||||||
|
hs["in"] = h
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](h, temb)
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
# hs.append(h)
|
||||||
|
hs["block_" + str(i_level)] = h
|
||||||
|
h = self.down[i_level].downsample(h)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
# h = hs[-1]
|
||||||
|
if self.enable_mid:
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
hs["block_" + str(i_level) + "_atten"] = h
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
hs["mid_atten"] = h
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
# hs.append(h)
|
||||||
|
hs["out"] = h
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_resolutions=(16,),
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels=3,
|
||||||
|
resolution=512,
|
||||||
|
z_channels=256,
|
||||||
|
give_pre_end=False,
|
||||||
|
enable_mid=True,
|
||||||
|
head_size=1,
|
||||||
|
**ignorekwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.give_pre_end = give_pre_end
|
||||||
|
self.enable_mid = enable_mid
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = torch.nn.Conv2d(
|
||||||
|
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
if self.enable_mid:
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
# assert z.shape[1:] == self.z_shape[1:]
|
||||||
|
self.last_z_shape = z.shape
|
||||||
|
|
||||||
|
# timestep embedding
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
if self.enable_mid:
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h, temb)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
if self.give_pre_end:
|
||||||
|
return h
|
||||||
|
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadDecoderTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_resolutions=(16,),
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels=3,
|
||||||
|
resolution=512,
|
||||||
|
z_channels=256,
|
||||||
|
give_pre_end=False,
|
||||||
|
enable_mid=True,
|
||||||
|
head_size=1,
|
||||||
|
**ignorekwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.give_pre_end = give_pre_end
|
||||||
|
self.enable_mid = enable_mid
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = torch.nn.Conv2d(
|
||||||
|
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
if self.enable_mid:
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, z, hs):
|
||||||
|
# assert z.shape[1:] == self.z_shape[1:]
|
||||||
|
# self.last_z_shape = z.shape
|
||||||
|
|
||||||
|
# timestep embedding
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
if self.enable_mid:
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
h = self.mid.attn_1(h, hs["mid_atten"])
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h, temb)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](
|
||||||
|
h, hs["block_" + str(i_level) + "_atten"]
|
||||||
|
)
|
||||||
|
# hfeature = h.clone()
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
if self.give_pre_end:
|
||||||
|
return h
|
||||||
|
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class RestoreFormer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_embed=1024,
|
||||||
|
embed_dim=256,
|
||||||
|
ch=64,
|
||||||
|
out_ch=3,
|
||||||
|
ch_mult=(1, 2, 2, 4, 4, 8),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_resolutions=(16,),
|
||||||
|
dropout=0.0,
|
||||||
|
in_channels=3,
|
||||||
|
resolution=512,
|
||||||
|
z_channels=256,
|
||||||
|
double_z=False,
|
||||||
|
enable_mid=True,
|
||||||
|
fix_decoder=False,
|
||||||
|
fix_codebook=True,
|
||||||
|
fix_encoder=False,
|
||||||
|
head_size=8,
|
||||||
|
):
|
||||||
|
super(RestoreFormer, self).__init__()
|
||||||
|
|
||||||
|
self.encoder = MultiHeadEncoder(
|
||||||
|
ch=ch,
|
||||||
|
out_ch=out_ch,
|
||||||
|
ch_mult=ch_mult,
|
||||||
|
num_res_blocks=num_res_blocks,
|
||||||
|
attn_resolutions=attn_resolutions,
|
||||||
|
dropout=dropout,
|
||||||
|
in_channels=in_channels,
|
||||||
|
resolution=resolution,
|
||||||
|
z_channels=z_channels,
|
||||||
|
double_z=double_z,
|
||||||
|
enable_mid=enable_mid,
|
||||||
|
head_size=head_size,
|
||||||
|
)
|
||||||
|
self.decoder = MultiHeadDecoderTransformer(
|
||||||
|
ch=ch,
|
||||||
|
out_ch=out_ch,
|
||||||
|
ch_mult=ch_mult,
|
||||||
|
num_res_blocks=num_res_blocks,
|
||||||
|
attn_resolutions=attn_resolutions,
|
||||||
|
dropout=dropout,
|
||||||
|
in_channels=in_channels,
|
||||||
|
resolution=resolution,
|
||||||
|
z_channels=z_channels,
|
||||||
|
enable_mid=enable_mid,
|
||||||
|
head_size=head_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
|
||||||
|
|
||||||
|
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
|
||||||
|
|
||||||
|
if fix_decoder:
|
||||||
|
for _, param in self.decoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
for _, param in self.post_quant_conv.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
for _, param in self.quantize.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
elif fix_codebook:
|
||||||
|
for _, param in self.quantize.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if fix_encoder:
|
||||||
|
for _, param in self.encoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
hs = self.encoder(x)
|
||||||
|
h = self.quant_conv(hs["out"])
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
return quant, emb_loss, info, hs
|
||||||
|
|
||||||
|
def decode(self, quant, hs):
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant, hs)
|
||||||
|
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, **kwargs):
|
||||||
|
quant, diff, info, hs = self.encode(input)
|
||||||
|
dec = self.decode(quant, hs)
|
||||||
|
|
||||||
|
return dec, None
|
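
A sketch of the RestoreFormer data flow with the defaults above: encode to a 16x16 latent, vector-quantize it, then decode with cross-attention against the encoder's intermediate feature dict `hs`. Weights are random here; shapes assume the 512x512 default resolution:

```python
import torch

net = RestoreFormer()
net.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 512, 512)
    quant, emb_loss, info, hs = net.encode(x)   # quant: (1, 256, 16, 16)
    restored = net.decode(quant, hs)            # (1, 3, 512, 512)
    # Equivalent single call:
    restored2, _ = net(x)
```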
434
iopaint/plugins/gfpgan/archs/stylegan2_clean_arch.py
Normal file
434
iopaint/plugins/gfpgan/archs/stylegan2_clean_arch.py
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from iopaint.plugins.basicsr.arch_util import default_init_weights
|
||||||
|
|
||||||
|
|
||||||
|
class NormStyleCode(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
"""Normalize the style codes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Style codes with shape (b, c).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Normalized tensor.
|
||||||
|
"""
|
||||||
|
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedConv2d(nn.Module):
|
||||||
|
"""Modulated Conv2d used in StyleGAN2.
|
||||||
|
|
||||||
|
There is no bias in ModulatedConv2d.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Channel number of the input.
|
||||||
|
out_channels (int): Channel number of the output.
|
||||||
|
kernel_size (int): Size of the convolving kernel.
|
||||||
|
num_style_feat (int): Channel number of style features.
|
||||||
|
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
||||||
|
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
||||||
|
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
num_style_feat,
|
||||||
|
demodulate=True,
|
||||||
|
sample_mode=None,
|
||||||
|
eps=1e-8,
|
||||||
|
):
|
||||||
|
super(ModulatedConv2d, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.demodulate = demodulate
|
||||||
|
self.sample_mode = sample_mode
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
# modulation inside each modulated conv
|
||||||
|
self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
|
||||||
|
# initialization
|
||||||
|
default_init_weights(
|
||||||
|
self.modulation,
|
||||||
|
scale=1,
|
||||||
|
bias_fill=1,
|
||||||
|
a=0,
|
||||||
|
mode="fan_in",
|
||||||
|
nonlinearity="linear",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
|
||||||
|
/ math.sqrt(in_channels * kernel_size**2)
|
||||||
|
)
|
||||||
|
self.padding = kernel_size // 2
|
||||||
|
|
||||||
|
def forward(self, x, style):
|
||||||
|
"""Forward function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Tensor with shape (b, c, h, w).
|
||||||
|
style (Tensor): Tensor with shape (b, num_style_feat).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Modulated tensor after convolution.
|
||||||
|
"""
|
||||||
|
b, c, h, w = x.shape # c = c_in
|
||||||
|
# weight modulation
|
||||||
|
style = self.modulation(style).view(b, 1, c, 1, 1)
|
||||||
|
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
||||||
|
weight = self.weight * style # (b, c_out, c_in, k, k)
|
||||||
|
|
||||||
|
if self.demodulate:
|
||||||
|
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
||||||
|
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
||||||
|
|
||||||
|
weight = weight.view(
|
||||||
|
b * self.out_channels, c, self.kernel_size, self.kernel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsample or downsample if necessary
|
||||||
|
if self.sample_mode == "upsample":
|
||||||
|
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
|
||||||
|
elif self.sample_mode == "downsample":
|
||||||
|
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
x = x.view(1, b * c, h, w)
|
||||||
|
# weight: (b*c_out, c_in, k, k), groups=b
|
||||||
|
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
||||||
|
out = out.view(b, self.out_channels, *out.shape[2:4])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
|
||||||
|
f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
    ):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
        )
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out

class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
        )
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode="bilinear", align_corners=False
                )
            out = out + skip
        return out

class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out

class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(
        self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
    ):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [
                    nn.Linear(num_style_feat, num_style_feat, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                ]
            )
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(
            self.style_mlp,
            scale=1,
            bias_fill=0,
            a=0.2,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )

        # channel list
        channels = {
            "4": int(512 * narrow),
            "8": int(512 * narrow),
            "16": int(512 * narrow),
            "32": int(512 * narrow),
            "64": int(256 * channel_multiplier * narrow),
            "128": int(128 * channel_multiplier * narrow),
            "256": int(64 * channel_multiplier * narrow),
            "512": int(32 * channel_multiplier * narrow),
            "1024": int(16 * channel_multiplier * narrow),
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels["4"], size=4)
        self.style_conv1 = StyleConv(
            channels["4"],
            channels["4"],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
        )
        self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels["4"]
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2 ** ((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f"{2 ** i}"]
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode="upsample",
                )
            )
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                )
            )
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(
            num_latent, self.num_style_feat, device=self.constant_input.weight.device
        )
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(
        self,
        styles,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        truncation=1,
        truncation_latent=None,
        inject_index=None,
        return_latents=False,
    ):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [
                    getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
                ]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = (
                styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            )
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.style_convs[::2],
            self.style_convs[1::2],
            noise[1::2],
            noise[2::2],
            self.to_rgbs,
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
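A minimal sketch of driving the generator defined above; the module path and the 512-pixel/512-dim sizes are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (import path assumed from this commit's layout).
import torch

from iopaint.plugins.gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean

generator = StyleGAN2GeneratorClean(out_size=512, num_style_feat=512, num_mlp=8)
generator.eval()
with torch.no_grad():
    z = torch.randn(1, 512)  # one style code per image, size num_style_feat
    image, _ = generator([z], randomize_noise=True)
print(image.shape)  # (1, 3, 512, 512), values roughly in [-1, 1]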
@ -20,11 +20,6 @@ class GFPGANPlugin(BasePlugin):
        model_path = download_model(url, model_md5)
        logger.info(f"GFPGAN model path: {model_path}")

        import facexlib

        if hasattr(facexlib.detection.retinaface, "device"):
            facexlib.detection.retinaface.device = device

        # Use GFPGAN for face enhancement
        self.face_enhancer = MyGFPGANer(
            model_path=model_path,
@ -64,11 +59,3 @@ class GFPGANPlugin(BasePlugin):
        # except Exception as error:
        # print("wrong scale input.", error)
        return bgr_output

    def check_dep(self):
        try:
            import gfpgan
        except ImportError:
            return (
                "gfpgan is not installed, please install it first. pip install gfpgan"
            )
@ -1,12 +1,16 @@
import os

import cv2
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from gfpgan import GFPGANv1Clean, GFPGANer
from torch.hub import get_dir

from .facexlib.utils.face_restoration_helper import FaceRestoreHelper
from .gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from .basicsr.img_util import img2tensor, tensor2img


class MyGFPGANer(GFPGANer):
class MyGFPGANer:
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
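A rough round-trip sketch of the vendored img2tensor / tensor2img helpers imported above (the absolute import path is an assumption):

# Illustrative round trip with the vendored helpers; import path assumed.
import numpy as np

from iopaint.plugins.basicsr.img_util import img2tensor, tensor2img

bgr = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
t = img2tensor(bgr / 255.0, bgr2rgb=True, float32=True)  # CHW float tensor, RGB, [0, 1]
back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))  # HWC uint8 image, BGR again
assert back.shape == bgr.shape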
@ -55,7 +59,7 @@ class MyGFPGANer(GFPGANer):
                sft_half=True,
            )
        elif arch == "RestoreFormer":
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            from .gfpgan.archs.restoreformer_arch import RestoreFormer

            self.gfpgan = RestoreFormer()

@ -82,3 +86,71 @@ class MyGFPGANer(GFPGANer):
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(
        self,
        img,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
        weight=0.5,
    ):
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, eye_dist_threshold=5
            )
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(
                cropped_face / 255.0, bgr2rgb=True, float32=True
            )
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(
                    output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
                )
            except RuntimeError as error:
                print(f"\tFailed inference for GFPGAN: {error}.")
                restored_face = cropped_face

            restored_face = restored_face.astype("uint8")
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=bg_img
            )
            return (
                self.face_helper.cropped_faces,
                self.face_helper.restored_faces,
                restored_img,
            )
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
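A minimal sketch of calling the new enhance() method; the constructor arguments other than model_path and upscale are assumptions and may differ from the actual __init__ signature:

# Hypothetical usage sketch; enhance() signature is taken from the hunk above.
import cv2

restorer = MyGFPGANer(model_path="./GFPGANv1.4.pth", upscale=1, arch="clean", device="cpu")
bgr_img = cv2.imread("face.jpg")  # enhance() works on BGR numpy images
cropped_faces, restored_faces, restored_img = restorer.enhance(
    bgr_img,
    has_aligned=False,
    only_center_face=False,
    paste_back=True,
    weight=0.5,
)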
@ -466,6 +466,3 @@ class RealESRGANUpscaler(BasePlugin):
        # The output is BGR
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled

    def check_dep(self):
        pass
@ -20,11 +20,6 @@ class RestoreFormerPlugin(BasePlugin):
        model_path = download_model(url, model_md5)
        logger.info(f"RestoreFormer model path: {model_path}")

        import facexlib

        if hasattr(facexlib.detection.retinaface, "device"):
            facexlib.detection.retinaface.device = device

        self.face_enhancer = MyGFPGANer(
            model_path=model_path,
            upscale=1,
@ -47,11 +42,3 @@ class RestoreFormerPlugin(BasePlugin):
        )
        logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
        return bgr_output

    def check_dep(self):
        try:
            import gfpgan
        except ImportError:
            return (
                "gfpgan is not installed, please install it first. pip install gfpgan"
            )
@ -30,7 +30,6 @@ _CANDIDATES = [
    "accelerate",
    "iopaint",
    "rembg",
    "gfpgan",
]
# Check once at runtime
for name in _CANDIDATES:
@ -26,6 +26,11 @@ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})

person_p = current_dir / "image.png"
person_bgr_img = cv2.imread(str(person_p))
person_rgb_img = cv2.cvtColor(person_bgr_img, cv2.COLOR_BGR2RGB)
person_rgb_img = cv2.resize(person_rgb_img, (512, 512))


def _save(img, name):
    cv2.imwrite(str(save_dir / name), img)
@ -86,7 +91,7 @@ def test_gfpgan(device):
    check_device(device)
    model = GFPGANPlugin(device)
    res = model.gen_image(
        rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
        person_rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
    )
    _save(res, f"test_gfpgan_{device}.png")

@ -96,7 +101,8 @@ def test_restoreformer(device):
    check_device(device)
    model = RestoreFormerPlugin(device)
    res = model.gen_image(
        rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
        person_rgb_img,
        RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64),
    )
    _save(res, f"test_restoreformer_{device}.png")