diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index a4eefd7..df9fe98 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -101,6 +101,14 @@ def parse_args():
         type=str,
         choices=RealESRGANModelNameList,
     )
+    parser.add_argument(
+        "--enable-gfpgan",
+        action="store_true",
+        help="Enable GFPGAN face restore",
+    )
+    parser.add_argument(
+        "--gfpgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
+    )
     parser.add_argument(
         "--enable-gif",
         action="store_true",
diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py
index 87ae130..eb90430 100644
--- a/lama_cleaner/plugins/__init__.py
+++ b/lama_cleaner/plugins/__init__.py
@@ -1,4 +1,5 @@
 from .interactive_seg import InteractiveSeg, Click
 from .remove_bg import RemoveBG
 from .realesrgan import RealESRGANUpscaler
+from .gfpgan_plugin import GFPGANPlugin
 from .gif import MakeGIF
diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py
new file mode 100644
index 0000000..0d465e0
--- /dev/null
+++ b/lama_cleaner/plugins/gfpgan_plugin.py
@@ -0,0 +1,71 @@
+import cv2
+from loguru import logger
+
+from lama_cleaner.helper import download_model
+from lama_cleaner.plugins.base_plugin import BasePlugin
+
+
+class GFPGANPlugin(BasePlugin):
+    """Face-restoration plugin backed by the GFPGAN v1.4 checkpoint."""
+
+    name = "GFPGAN"
+
+    def __init__(self, device):
+        """Download the GFPGAN checkpoint and build the face enhancer.
+
+        Args:
+            device: Device handed to ``MyGFPGANer`` (the CLI offers "cpu" or "cuda").
+        """
+        super().__init__()
+        # Function-scope import: gfpganer imports gfpgan at module level, so this
+        # keeps the plugins package importable when gfpgan is not installed.
+        from .gfpganer import MyGFPGANer
+
+        url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+        model_md5 = "94d735072630ab734561130a47bc44f8"
+        model_path = download_model(url, model_md5)
+        logger.info(f"GFPGAN model path: {model_path}")
+
+        # Use GFPGAN for face enhancement
+        self.face_enhancer = MyGFPGANer(
+            model_path=model_path,
+            upscale=2,
+            arch="clean",
+            channel_multiplier=2,
+            device=device,
+        )
+
+    def __call__(self, rgb_np_img, files, form):
+        """Run GFPGAN face restoration on an RGB image.
+
+        Args:
+            rgb_np_img: Input image as an RGB numpy array.
+            files: Not used by this plugin.
+            form: Not used by this plugin.
+
+        Returns:
+            The restored image with faces pasted back, as produced by
+            ``enhance`` from the BGR-converted input.
+        """
+        weight = 0.5
+        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
+        logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
+        _, _, bgr_output = self.face_enhancer.enhance(
+            bgr_np_img,
+            has_aligned=False,
+            only_center_face=False,
+            paste_back=True,
+            weight=weight,
+        )
+        logger.info(f"GFPGAN output shape: {bgr_output.shape}")
+        return bgr_output
+
+    def check_dep(self):
+        """Return an error message when gfpgan is not installed, else None."""
+        try:
+            import gfpgan  # noqa: F401 (imported only to probe availability)
+        except ImportError:
+            return (
+                "gfpgan is not installed, please install it first. pip install gfpgan"
+            )
+        return None
diff --git a/lama_cleaner/plugins/gfpganer.py b/lama_cleaner/plugins/gfpganer.py
new file mode 100644
index 0000000..75a575d
--- /dev/null
+++ b/lama_cleaner/plugins/gfpganer.py
@@ -0,0 +1,98 @@
+import os
+
+import torch
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from gfpgan import GFPGANv1Clean, GFPGANer
+from torch.hub import get_dir
+
+
+class MyGFPGANer(GFPGANer):
+    """Helper for restoration with GFPGAN.
+
+    It will detect and crop faces, and then resize the faces to 512x512.
+    GFPGAN is used to restore the resized faces.
+    The background is upsampled with the bg_upsampler.
+    Finally, the faces will be pasted back to the upsampled background image.
+
+    This constructor does not call ``GFPGANer.__init__``; it builds the network
+    and face helper itself and points the face helper's ``model_rootpath`` at
+    the torch hub ``checkpoints`` directory.
+
+    Args:
+        model_path (str): Path to the GFPGAN checkpoint; it is read with ``torch.load``.
+        upscale (float): The upscale of the final output. Default: 2.
+        arch (str): The GFPGAN architecture. Option: clean | RestoreFormer. Default: clean.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
+        device (str | torch.device | None): Device to run on. When None, CUDA is
+            used if available, otherwise CPU. Default: None.
+
+    Raises:
+        ValueError: If ``arch`` is neither "clean" nor "RestoreFormer".
+    """
+
+    def __init__(
+        self,
+        model_path,
+        upscale=2,
+        arch="clean",
+        channel_multiplier=2,
+        bg_upsampler=None,
+        device=None,
+    ):
+        self.upscale = upscale
+        self.bg_upsampler = bg_upsampler
+
+        # initialize model
+        self.device = (
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if device is None
+            else device
+        )
+        # initialize the GFP-GAN
+        if arch == "clean":
+            self.gfpgan = GFPGANv1Clean(
+                out_size=512,
+                num_style_feat=512,
+                channel_multiplier=channel_multiplier,
+                decoder_load_path=None,
+                fix_decoder=False,
+                num_mlp=8,
+                input_is_latent=True,
+                different_w=True,
+                narrow=1,
+                sft_half=True,
+            )
+        elif arch == "RestoreFormer":
+            from gfpgan.archs.restoreformer_arch import RestoreFormer
+
+            self.gfpgan = RestoreFormer()
+        else:
+            # Fail fast: without this, self.gfpgan would be unset and the
+            # load_state_dict call below would raise an AttributeError.
+            raise ValueError(
+                f"Unsupported GFPGAN arch {arch!r}; expected 'clean' or 'RestoreFormer'"
+            )
+
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, "checkpoints")
+
+        # initialize face helper
+        self.face_helper = FaceRestoreHelper(
+            upscale,
+            face_size=512,
+            crop_ratio=(1, 1),
+            det_model="retinaface_resnet50",
+            save_ext="png",
+            use_parse=True,
+            device=self.device,
+            model_rootpath=model_dir,
+        )
+
+        # Deserialize onto CPU so a checkpoint saved from a CUDA device also
+        # loads on CPU-only hosts; the network is moved to self.device below.
+        loadnet = torch.load(model_path, map_location="cpu")
+        keyname = "params_ema" if "params_ema" in loadnet else "params"
+        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
+        self.gfpgan.eval()
+        self.gfpgan = self.gfpgan.to(self.device)
diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py
index 6876b82..a408b51 100644
--- a/lama_cleaner/plugins/realesrgan.py
+++ b/lama_cleaner/plugins/realesrgan.py
@@ -71,6 +71,7 @@ class RealESRGANUpscaler(BasePlugin):
         model_info = REAL_ESRGAN_MODELS[name]
 
         model_path = download_model(model_info["url"], model_info["model_md5"])
+        logger.info(f"RealESRGAN model path: {model_path}")
 
         self.model = RealESRGANer(
             scale=model_info["scale"],
diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py
index 15dfc5b..37ce41b 100644
--- a/lama_cleaner/plugins/remove_bg.py
+++ b/lama_cleaner/plugins/remove_bg.py
@@ -1,5 +1,7 @@
+import os
 import cv2
 import numpy as np
+from torch.hub import get_dir
 
 from lama_cleaner.plugins.base_plugin import BasePlugin
 
@@ -11,6 +13,10 @@ class RemoveBG(BasePlugin):
         super().__init__()
         from rembg import new_session
 
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, "checkpoints")
+        os.environ["U2NET_HOME"] = model_dir
+
         self.session = new_session(model_name="u2net")
 
     def __call__(self, rgb_np_img, files, form):
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index 7216410..99b5958 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -19,7 +19,13 @@ from lama_cleaner.const import SD15_MODELS
 from lama_cleaner.file_manager import FileManager
 from lama_cleaner.model.utils import torch_gc
 from lama_cleaner.model_manager import ModelManager
-from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler, MakeGIF
+from lama_cleaner.plugins import (
+    InteractiveSeg,
+    RemoveBG,
+    RealESRGANUpscaler,
+    MakeGIF,
+    GFPGANPlugin,
+)
 from lama_cleaner.schema import Config
 
 try:
@@ -423,6 +429,9 @@ def build_plugins(args):
         plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
             args.realesrgan_model, args.realesrgan_device
         )
+    if args.enable_gfpgan:
+        logger.info(f"Initialize {GFPGANPlugin.name} plugin")
+        plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device)
     if args.enable_gif:
         logger.info(f"Initialize GIF plugin")
         plugins[MakeGIF.name] = MakeGIF()