pass upscaler to GFPGAN

This commit is contained in:
Qing 2023-03-26 20:52:06 +08:00
parent b200db920b
commit 4d1809e908
2 changed files with 5 additions and 2 deletions

View File

@ -8,7 +8,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin
class GFPGANPlugin(BasePlugin):
name = "GFPGAN"
def __init__(self, device):
def __init__(self, device, upscaler=None):
super().__init__()
from .gfpganer import MyGFPGANer
@ -24,6 +24,7 @@ class GFPGANPlugin(BasePlugin):
arch="clean",
channel_multiplier=2,
device=device,
bg_upsampler=upscaler.model if upscaler is not None else None,
)
def __call__(self, rgb_np_img, files, form):

View File

@ -449,7 +449,9 @@ def build_plugins(args):
)
if args.enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device)
plugins[GFPGANPlugin.name] = GFPGANPlugin(
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
)
if args.enable_gif:
logger.info(f"Initialize GIF plugin")
plugins[MakeGIF.name] = MakeGIF()