add --realesrgan-no-half

This commit is contained in:
Qing 2023-04-03 13:32:04 +08:00
parent dd1d45aa79
commit 03206fb8d6
3 changed files with 10 additions and 3 deletions

View File

@ -104,6 +104,11 @@ def parse_args():
type=str, type=str,
choices=RealESRGANModelNameList, choices=RealESRGANModelNameList,
) )
parser.add_argument(
"--realesrgan-no-half",
action="store_true",
help="Disable half precision for RealESRGAN",
)
parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP) parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP)
parser.add_argument( parser.add_argument(
"--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES "--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES

View File

@ -11,7 +11,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin
class RealESRGANUpscaler(BasePlugin): class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN" name = "RealESRGAN"
def __init__(self, name, device): def __init__(self, name, device, no_half=False):
super().__init__() super().__init__()
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
@ -69,7 +69,7 @@ class RealESRGANUpscaler(BasePlugin):
scale=model_info["scale"], scale=model_info["scale"],
model_path=model_path, model_path=model_path,
model=model_info["model"](), model=model_info["model"](),
half=True if "cuda" in str(device) else False, half=True if "cuda" in str(device) and not no_half else False,
tile=512, tile=512,
tile_pad=10, tile_pad=10,
pre_pad=10, pre_pad=10,

View File

@ -446,7 +446,9 @@ def build_plugins(args):
f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}" f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
) )
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
args.realesrgan_model, args.realesrgan_device args.realesrgan_model,
args.realesrgan_device,
no_half=args.realesrgan_no_half,
) )
if args.enable_gfpgan: if args.enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin") logger.info(f"Initialize {GFPGANPlugin.name} plugin")