add GFPGAN model

This commit is contained in:
Qing 2023-03-26 13:39:09 +08:00
parent d938f2da3c
commit e7c7896bfa
7 changed files with 171 additions and 1 deletions

View File

@ -101,6 +101,14 @@ def parse_args():
type=str,
choices=RealESRGANModelNameList,
)
parser.add_argument(
"--enable-gfpgan",
action="store_true",
help="Enable GFPGAN face restore",
)
parser.add_argument(
"--gfpgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
)
parser.add_argument(
"--enable-gif",
action="store_true",

View File

@ -1,4 +1,5 @@
from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG
from .realesrgan import RealESRGANUpscaler
from .gfpgan_plugin import GFPGANPlugin
from .gif import MakeGIF

View File

@ -0,0 +1,61 @@
import cv2
from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
class GFPGANPlugin(BasePlugin):
name = "GFPGAN"
def __init__(self, device):
super().__init__()
from .gfpganer import MyGFPGANer
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
model_md5 = "94d735072630ab734561130a47bc44f8"
model_path = download_model(url, model_md5)
logger.info(f"GFPGAN model path: {model_path}")
# Use GFPGAN for face enhancement
self.face_enhancer = MyGFPGANer(
model_path=model_path,
upscale=2,
arch="clean",
channel_multiplier=2,
device=device,
)
def __call__(self, rgb_np_img, files, form):
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
_, _, bgr_output = self.face_enhancer.enhance(
bgr_np_img,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=weight,
)
logger.info(f"GFPGAN output shape: {bgr_output.shape}")
# try:
# if scale != 2:
# interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
# h, w = img.shape[0:2]
# output = cv2.resize(
# output,
# (int(w * scale / 2), int(h * scale / 2)),
# interpolation=interpolation,
# )
# except Exception as error:
# print("wrong scale input.", error)
return bgr_output
def check_dep(self):
try:
import gfpgan
except ImportError:
return (
"gfpgan is not installed, please install it first. pip install gfpgan"
)

View File

@ -0,0 +1,84 @@
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torch.hub import get_dir
class MyGFPGANer(GFPGANer):
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restored the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsample background image.
Args:
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(
self,
model_path,
upscale=2,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
device=None,
):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
# initialize the GFP-GAN
if arch == "clean":
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True,
)
elif arch == "RestoreFormer":
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
use_parse=True,
device=self.device,
model_rootpath=model_dir,
)
loadnet = torch.load(model_path)
if "params_ema" in loadnet:
keyname = "params_ema"
else:
keyname = "params"
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)

View File

@ -71,6 +71,7 @@ class RealESRGANUpscaler(BasePlugin):
model_info = REAL_ESRGAN_MODELS[name]
model_path = download_model(model_info["url"], model_info["model_md5"])
logger.info(f"RealESRGAN model path: {model_path}")
self.model = RealESRGANer(
scale=model_info["scale"],

View File

@ -1,5 +1,7 @@
import os
import cv2
import numpy as np
from torch.hub import get_dir
from lama_cleaner.plugins.base_plugin import BasePlugin
@ -11,6 +13,10 @@ class RemoveBG(BasePlugin):
super().__init__()
from rembg import new_session
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.environ["U2NET_HOME"] = model_dir
self.session = new_session(model_name="u2net")
def __call__(self, rgb_np_img, files, form):

View File

@ -19,7 +19,13 @@ from lama_cleaner.const import SD15_MODELS
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler, MakeGIF
from lama_cleaner.plugins import (
InteractiveSeg,
RemoveBG,
RealESRGANUpscaler,
MakeGIF,
GFPGANPlugin,
)
from lama_cleaner.schema import Config
try:
@ -423,6 +429,9 @@ def build_plugins(args):
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
args.realesrgan_model, args.realesrgan_device
)
if args.enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device)
if args.enable_gif:
logger.info(f"Initialize GIF plugin")
plugins[MakeGIF.name] = MakeGIF()