add restoreformer
This commit is contained in:
parent
f2e90d3f84
commit
c52f733214
@ -616,6 +616,15 @@ export default function Editor() {
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.RestoreFormer, () => {
|
||||
runRenderablePlugin(PluginName.RestoreFormer)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.RestoreFormer)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.RealESRGAN, (data: any) => {
|
||||
runRenderablePlugin(PluginName.RealESRGAN, data)
|
||||
|
@ -48,7 +48,7 @@
|
||||
border-radius: 3px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 25px;
|
||||
height: 32px;
|
||||
padding: 0 5px;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
|
@ -22,6 +22,7 @@ export enum PluginName {
|
||||
RemoveBG = 'RemoveBG',
|
||||
RealESRGAN = 'RealESRGAN',
|
||||
GFPGAN = 'GFPGAN',
|
||||
RestoreFormer = 'RestoreFormer',
|
||||
InteractiveSeg = 'InteractiveSeg',
|
||||
MakeGIF = 'MakeGIF',
|
||||
}
|
||||
@ -39,6 +40,10 @@ const pluginMap = {
|
||||
IconClass: FaceIcon,
|
||||
showName: 'GFPGAN',
|
||||
},
|
||||
[PluginName.RestoreFormer]: {
|
||||
IconClass: FaceIcon,
|
||||
showName: 'RestoreFormer',
|
||||
},
|
||||
[PluginName.InteractiveSeg]: {
|
||||
IconClass: CursorArrowRaysIcon,
|
||||
showName: 'Interactive Seg',
|
||||
|
@ -109,4 +109,6 @@ GFPGAN_HELP = (
|
||||
"Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
|
||||
)
|
||||
GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
||||
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
|
||||
RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
||||
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
||||
|
@ -106,7 +106,16 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP)
|
||||
parser.add_argument(
|
||||
"--gfpgan-device", default="cpu", type=str, choices=["cpu", "cuda", "mps"]
|
||||
"--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-restoreformer", action="store_true", help=RESTOREFORMER_HELP
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restoreformer-device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
choices=RESTOREFORMER_AVAILABLE_DEVICES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-gif",
|
||||
|
@ -2,4 +2,5 @@ from .interactive_seg import InteractiveSeg, Click
|
||||
from .remove_bg import RemoveBG
|
||||
from .realesrgan import RealESRGANUpscaler
|
||||
from .gfpgan_plugin import GFPGANPlugin
|
||||
from .restoreformer import RestoreFormerPlugin
|
||||
from .gif import MakeGIF
|
||||
|
54
lama_cleaner/plugins/restoreformer.py
Normal file
54
lama_cleaner/plugins/restoreformer.py
Normal file
@ -0,0 +1,54 @@
|
||||
import cv2
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
|
||||
|
||||
class RestoreFormerPlugin(BasePlugin):
|
||||
name = "RestoreFormer"
|
||||
|
||||
def __init__(self, device, upscaler=None):
|
||||
super().__init__()
|
||||
from .gfpganer import MyGFPGANer
|
||||
|
||||
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
|
||||
model_md5 = "eaeeff6c4a1caa1673977cb374e6f699"
|
||||
model_path = download_model(url, model_md5)
|
||||
logger.info(f"RestoreFormer model path: {model_path}")
|
||||
|
||||
import facexlib
|
||||
|
||||
if hasattr(facexlib.detection.retinaface, "device"):
|
||||
facexlib.detection.retinaface.device = device
|
||||
|
||||
self.face_enhancer = MyGFPGANer(
|
||||
model_path=model_path,
|
||||
upscale=2,
|
||||
arch="RestoreFormer",
|
||||
channel_multiplier=2,
|
||||
device=device,
|
||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||
_, _, bgr_output = self.face_enhancer.enhance(
|
||||
bgr_np_img,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
weight=weight,
|
||||
)
|
||||
logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
|
||||
return bgr_output
|
||||
|
||||
def check_dep(self):
|
||||
try:
|
||||
import gfpgan
|
||||
except ImportError:
|
||||
return (
|
||||
"gfpgan is not installed, please install it first. pip install gfpgan"
|
||||
)
|
@ -27,6 +27,7 @@ from lama_cleaner.plugins import (
|
||||
RealESRGANUpscaler,
|
||||
MakeGIF,
|
||||
GFPGANPlugin,
|
||||
RestoreFormerPlugin
|
||||
)
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@ -113,25 +114,6 @@ def diffuser_callback(i, t, latents):
|
||||
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
||||
|
||||
|
||||
@app.route("/make_gif", methods=["POST"])
|
||||
def make_gif():
|
||||
input = request.files
|
||||
filename = request.form["filename"]
|
||||
origin_image_bytes = input["origin_img"].read()
|
||||
clean_image_bytes = input["clean_img"].read()
|
||||
origin_image, _ = load_img(origin_image_bytes)
|
||||
clean_image, _ = load_img(clean_image_bytes)
|
||||
gif_bytes = make_compare_gif(
|
||||
Image.fromarray(origin_image), Image.fromarray(clean_image)
|
||||
)
|
||||
return send_file(
|
||||
io.BytesIO(gif_bytes),
|
||||
mimetype="image/gif",
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/save_image", methods=["POST"])
|
||||
def save_image():
|
||||
if output_dir is None:
|
||||
@ -461,6 +443,11 @@ def build_plugins(args):
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
)
|
||||
if args.enable_restoreformer:
|
||||
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
||||
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
args.restoreformer_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
)
|
||||
if args.enable_gif:
|
||||
logger.info(f"Initialize GIF plugin")
|
||||
plugins[MakeGIF.name] = MakeGIF()
|
||||
|
@ -36,6 +36,8 @@ class Config(BaseModel):
|
||||
realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
|
||||
enable_gfpgan: bool = False
|
||||
gfpgan_device: str = "cpu"
|
||||
enable_restoreformer: bool = False
|
||||
restoreformer_device: str = "cpu"
|
||||
enable_gif: bool = False
|
||||
|
||||
|
||||
@ -72,6 +74,8 @@ def save_config(
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
enable_gif,
|
||||
):
|
||||
config = Config(**locals())
|
||||
@ -179,6 +183,15 @@ def main(config_file: str):
|
||||
label="GFPGAN Device",
|
||||
value=init_config.gfpgan_device,
|
||||
)
|
||||
with gr.Row():
|
||||
enable_restoreformer = gr.Checkbox(
|
||||
init_config.enable_restoreformer, label=RESTOREFORMER_HELP
|
||||
)
|
||||
restoreformer_device = gr.Radio(
|
||||
RESTOREFORMER_AVAILABLE_DEVICES,
|
||||
label="RestoreFormer Device",
|
||||
value=init_config.restoreformer_device,
|
||||
)
|
||||
enable_gif = gr.Checkbox(init_config.enable_gif, label=GIF_HELP)
|
||||
|
||||
with gr.Tab("Diffusion Model"):
|
||||
@ -229,6 +242,8 @@ def main(config_file: str):
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
enable_gif,
|
||||
],
|
||||
message,
|
||||
|
Loading…
Reference in New Issue
Block a user