diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index 0d465e0..cee1834 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -8,7 +8,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin class GFPGANPlugin(BasePlugin): name = "GFPGAN" - def __init__(self, device): + def __init__(self, device, upscaler=None): super().__init__() from .gfpganer import MyGFPGANer @@ -24,6 +24,7 @@ class GFPGANPlugin(BasePlugin): arch="clean", channel_multiplier=2, device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, ) def __call__(self, rgb_np_img, files, form): diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index cb79b8e..8ea10b9 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -449,7 +449,9 @@ def build_plugins(args): ) if args.enable_gfpgan: logger.info(f"Initialize {GFPGANPlugin.name} plugin") - plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device) + plugins[GFPGANPlugin.name] = GFPGANPlugin( + args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None) + ) if args.enable_gif: logger.info(f"Initialize GIF plugin") plugins[MakeGIF.name] = MakeGIF()