# IOPaint/iopaint/plugins/realesrgan.py

2023-03-22 05:57:18 +01:00
import cv2
import numpy as np
import torch
from loguru import logger
2023-03-22 05:57:18 +01:00
2024-01-05 08:19:23 +01:00
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
2024-01-31 14:51:34 +01:00
from iopaint.schema import RunPluginRequest, RealESRGANModel
2023-03-22 05:57:18 +01:00
2023-03-26 06:37:58 +02:00
class RealESRGANUpscaler(BasePlugin):
    """Plugin that upscales images with a RealESRGAN network.

    The concrete network is selected by model name and can be swapped at
    runtime through :meth:`switch_model`.
    """

    name = "RealESRGAN"
    support_gen_image = True

    def __init__(self, name, device, no_half=False):
        """Store the configuration and load the requested model.

        Args:
            name: Key of the RealESRGAN variant to load; must be one of the
                ``RealESRGANModel`` members handled by :meth:`_init_model`.
            device: Device handed to ``RealESRGANer``. Its string form is
                inspected for ``"cuda"`` to decide on half precision.
            no_half: When true, half precision is never requested, even on
                a CUDA device.
        """
        super().__init__()
        self.model_name = name
        self.device = device
        self.no_half = no_half
        self._init_model(name)

    def _init_model(self, name):
        """Build ``self.model`` for the model identified by ``name``.

        Args:
            name: Key into the table of supported RealESRGAN variants.

        Raises:
            ValueError: If ``name`` is not one of the supported variants.
        """
        # Imported here rather than at module level, so importing this module
        # does not require these packages to be installed.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact

        def build_general_x4v3():
            return SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=32,
                upscale=4,
                act_type="prelu",
            )

        def build_x4plus():
            return RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )

        def build_x4plus_anime_6b():
            return RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=6,
                num_grow_ch=32,
                scale=4,
            )

        # Each entry holds the weights URL, its MD5, the native scale and a
        # factory for the network. The factory is only invoked for the entry
        # that is actually selected.
        supported = {
            RealESRGANModel.realesr_general_x4v3: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                "scale": 4,
                "model": build_general_x4v3,
                "model_md5": "91a7644643c884ee00737db24e478156",
            },
            RealESRGANModel.RealESRGAN_x4plus: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                "scale": 4,
                "model": build_x4plus,
                "model_md5": "99ec365d4afad750833258a1a24f44ca",
            },
            RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                "scale": 4,
                "model": build_x4plus_anime_6b,
                "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
            },
        }

        spec = supported.get(name)
        if spec is None:
            raise ValueError(f"Unknown RealESRGAN model name: {name}")

        # NOTE(review): download_model presumably fetches/caches the weights
        # and checks them against the MD5 -- confirm in iopaint.helper.
        model_path = download_model(spec["url"], spec["model_md5"])
        logger.info(f"RealESRGAN model path: {model_path}")

        network = spec["model"]()
        # Half precision is requested only on CUDA devices and only when the
        # caller has not opted out via ``no_half``.
        use_half = "cuda" in str(self.device) and not self.no_half

        self.model = RealESRGANer(
            scale=spec["scale"],
            model_path=model_path,
            model=network,
            half=use_half,
            tile=512,
            tile_pad=10,
            pre_pad=10,
            device=self.device,
        )

    def switch_model(self, new_model_name: str):
        """Load ``new_model_name`` unless it is already the active model.

        ``self.model_name`` is updated only after the new model has been
        built, so a failed switch leaves the recorded name unchanged.
        """
        if self.model_name != new_model_name:
            self._init_model(new_model_name)
            self.model_name = new_model_name

    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Upscale an RGB image by ``req.scale``.

        The input is converted from RGB to BGR before being passed to
        :meth:`forward`; the result of :meth:`forward` is returned as-is,
        without converting back to RGB.
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
        upscaled = self.forward(bgr_np_img, req.scale)
        logger.info(f"RealESRGAN output shape: {upscaled.shape}")
        return upscaled

    @torch.inference_mode()
    def forward(self, bgr_np_img, scale: float):
        """Run the loaded model on a BGR image and return its first output."""
        # Output is BGR.
        return self.model.enhance(bgr_np_img, outscale=scale)[0]

    def check_dep(self):
        """Return an error message if ``realesrgan`` cannot be imported.

        Returns:
            A human-readable installation hint when the import fails,
            otherwise ``None``.
        """
        try:
            import realesrgan  # noqa: F401 -- imported only to probe availability
        except ImportError:
            return "RealESRGAN is not installed, please install it first. pip install realesrgan"
        return None