pass upscaler to GFPGAN
This commit is contained in:
parent
b200db920b
commit
4d1809e908
@ -8,7 +8,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
class GFPGANPlugin(BasePlugin):
|
||||
name = "GFPGAN"
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device, upscaler=None):
|
||||
super().__init__()
|
||||
from .gfpganer import MyGFPGANer
|
||||
|
||||
@ -24,6 +24,7 @@ class GFPGANPlugin(BasePlugin):
|
||||
arch="clean",
|
||||
channel_multiplier=2,
|
||||
device=device,
|
||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
|
@ -449,7 +449,9 @@ def build_plugins(args):
|
||||
)
|
||||
if args.enable_gfpgan:
|
||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device)
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
)
|
||||
if args.enable_gif:
|
||||
logger.info(f"Initialize GIF plugin")
|
||||
plugins[MakeGIF.name] = MakeGIF()
|
||||
|
Loading…
Reference in New Issue
Block a user