import math import cv2 import numpy as np import torch from torch import nn import torch.nn.functional as F from loguru import logger from iopaint.helper import download_model from iopaint.plugins.base_plugin import BasePlugin from iopaint.schema import RunPluginRequest, RealESRGANModel class RealESRGANer: """A helper class for upsampling images with RealESRGAN. Args: scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). model (nn.Module): The defined network. Default: None. tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop input images into tiles, and then process each of them. Finally, they will be merged into one image. 0 denotes for do not use tile. Default: 0. tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. half (float): Whether to use half precision during inference. Default: False. """ def __init__( self, scale, model_path, dni_weight=None, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None, gpu_id=None, ): self.scale = scale self.tile_size = tile self.tile_pad = tile_pad self.pre_pad = pre_pad self.mod_scale = None self.half = half # initialize model if gpu_id: self.device = ( torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") if device is None else device ) else: self.device = ( torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device ) if isinstance(model_path, list): # dni assert len(model_path) == len( dni_weight ), "model_path and dni_weight should have the save length." loadnet = self.dni(model_path[0], model_path[1], dni_weight) else: # if the model_path starts with https, it will first download models to the folder: weights loadnet = torch.load(model_path, map_location=torch.device("cpu")) # prefer to use params_ema if "params_ema" in loadnet: keyname = "params_ema" else: keyname = "params" model.load_state_dict(loadnet[keyname], strict=True) model.eval() self.model = model.to(self.device) if self.half: self.model = self.model.half() def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"): """Deep network interpolation. ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` """ net_a = torch.load(net_a, map_location=torch.device(loc)) net_b = torch.load(net_b, map_location=torch.device(loc)) for k, v_a in net_a[key].items(): net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] return net_a def pre_process(self, img): """Pre-process, such as pre-pad and mod pad, so that the images can be divisible""" img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() self.img = img.unsqueeze(0).to(self.device) if self.half: self.img = self.img.half() # pre_pad if self.pre_pad != 0: self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect") # mod pad for divisible borders if self.scale == 2: self.mod_scale = 2 elif self.scale == 1: self.mod_scale = 4 if self.mod_scale is not None: self.mod_pad_h, self.mod_pad_w = 0, 0 _, _, h, w = self.img.size() if h % self.mod_scale != 0: self.mod_pad_h = self.mod_scale - h % self.mod_scale if w % self.mod_scale != 0: self.mod_pad_w = self.mod_scale - w % self.mod_scale self.img = F.pad( self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect" ) def process(self): # model inference self.output = self.model(self.img) def tile_process(self): """It will first crop input images to tiles, and then process each tile. Finally, all the processed tiles are merged into one images. Modified from: https://github.com/ata4/esrgan-launcher """ batch, channel, height, width = self.img.shape output_height = height * self.scale output_width = width * self.scale output_shape = (batch, channel, output_height, output_width) # start with black image self.output = self.img.new_zeros(output_shape) tiles_x = math.ceil(width / self.tile_size) tiles_y = math.ceil(height / self.tile_size) # loop over all tiles for y in range(tiles_y): for x in range(tiles_x): # extract tile from input image ofs_x = x * self.tile_size ofs_y = y * self.tile_size # input tile area on total image input_start_x = ofs_x input_end_x = min(ofs_x + self.tile_size, width) input_start_y = ofs_y input_end_y = min(ofs_y + self.tile_size, height) # input tile area on total image with padding input_start_x_pad = max(input_start_x - self.tile_pad, 0) input_end_x_pad = min(input_end_x + self.tile_pad, width) input_start_y_pad = max(input_start_y - self.tile_pad, 0) input_end_y_pad = min(input_end_y + self.tile_pad, height) # input tile dimensions input_tile_width = input_end_x - input_start_x input_tile_height = input_end_y - input_start_y tile_idx = y * tiles_x + x + 1 input_tile = self.img[ :, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad, ] # upscale tile try: with torch.no_grad(): output_tile = self.model(input_tile) except RuntimeError as error: print("Error", error) print(f"\tTile {tile_idx}/{tiles_x * tiles_y}") # output tile area on total image output_start_x = input_start_x * self.scale output_end_x = input_end_x * self.scale output_start_y = input_start_y * self.scale output_end_y = input_end_y * self.scale # output tile area without padding output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale output_end_x_tile = output_start_x_tile + input_tile_width * self.scale output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale output_end_y_tile = output_start_y_tile + input_tile_height * self.scale # put tile into output image self.output[ :, :, output_start_y:output_end_y, output_start_x:output_end_x ] = output_tile[ :, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile, ] def post_process(self): # remove extra pad if self.mod_scale is not None: _, _, h, w = self.output.size() self.output = self.output[ :, :, 0 : h - self.mod_pad_h * self.scale, 0 : w - self.mod_pad_w * self.scale, ] # remove prepad if self.pre_pad != 0: _, _, h, w = self.output.size() self.output = self.output[ :, :, 0 : h - self.pre_pad * self.scale, 0 : w - self.pre_pad * self.scale, ] return self.output @torch.no_grad() def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"): h_input, w_input = img.shape[0:2] # img: numpy img = img.astype(np.float32) if np.max(img) > 256: # 16-bit image max_range = 65535 print("\tInput is a 16-bit image") else: max_range = 255 img = img / max_range if len(img.shape) == 2: # gray image img_mode = "L" img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) elif img.shape[2] == 4: # RGBA image with alpha channel img_mode = "RGBA" alpha = img[:, :, 3] img = img[:, :, 0:3] img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if alpha_upsampler == "realesrgan": alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) else: img_mode = "RGB" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # ------------------- process image (without the alpha channel) ------------------- # self.pre_process(img) if self.tile_size > 0: self.tile_process() else: self.process() output_img = self.post_process() output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) if img_mode == "L": output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) # ------------------- process the alpha channel if necessary ------------------- # if img_mode == "RGBA": if alpha_upsampler == "realesrgan": self.pre_process(alpha) if self.tile_size > 0: self.tile_process() else: self.process() output_alpha = self.post_process() output_alpha = ( output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() ) output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) else: # use the cv2 resize for alpha channel h, w = alpha.shape[0:2] output_alpha = cv2.resize( alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR, ) # merge the alpha channel output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) output_img[:, :, 3] = output_alpha # ------------------------------ return ------------------------------ # if max_range == 65535: # 16-bit image output = (output_img * 65535.0).round().astype(np.uint16) else: output = (output_img * 255.0).round().astype(np.uint8) if outscale is not None and outscale != float(self.scale): output = cv2.resize( output, ( int(w_input * outscale), int(h_input * outscale), ), interpolation=cv2.INTER_LANCZOS4, ) return output, img_mode class SRVGGNetCompact(nn.Module): """A compact VGG-style network structure for super-resolution. It is a compact network structure, which performs upsampling in the last layer and no convolution is conducted on the HR feature space. Args: num_in_ch (int): Channel number of inputs. Default: 3. num_out_ch (int): Channel number of outputs. Default: 3. num_feat (int): Channel number of intermediate features. Default: 64. num_conv (int): Number of convolution layers in the body network. Default: 16. upscale (int): Upsampling factor. Default: 4. act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. """ def __init__( self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type="prelu", ): super(SRVGGNetCompact, self).__init__() self.num_in_ch = num_in_ch self.num_out_ch = num_out_ch self.num_feat = num_feat self.num_conv = num_conv self.upscale = upscale self.act_type = act_type self.body = nn.ModuleList() # the first conv self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) # the first activation if act_type == "relu": activation = nn.ReLU(inplace=True) elif act_type == "prelu": activation = nn.PReLU(num_parameters=num_feat) elif act_type == "leakyrelu": activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.body.append(activation) # the body structure for _ in range(num_conv): self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) # activation if act_type == "relu": activation = nn.ReLU(inplace=True) elif act_type == "prelu": activation = nn.PReLU(num_parameters=num_feat) elif act_type == "leakyrelu": activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.body.append(activation) # the last conv self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) # upsample self.upsampler = nn.PixelShuffle(upscale) def forward(self, x): out = x for i in range(0, len(self.body)): out = self.body[i](out) out = self.upsampler(out) # add the nearest upsampled image, so that the network learns the residual base = F.interpolate(x, scale_factor=self.upscale, mode="nearest") out += base return out class RealESRGANUpscaler(BasePlugin): name = "RealESRGAN" support_gen_image = True def __init__(self, name, device, no_half=False): super().__init__() self.model_name = name self.device = device self.no_half = no_half self._init_model(name) def _init_model(self, name): from .basicsr import RRDBNet REAL_ESRGAN_MODELS = { RealESRGANModel.realesr_general_x4v3: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "scale": 4, "model": lambda: SRVGGNetCompact( num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type="prelu", ), "model_md5": "91a7644643c884ee00737db24e478156", }, RealESRGANModel.RealESRGAN_x4plus: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", "scale": 4, "model": lambda: RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4, ), "model_md5": "99ec365d4afad750833258a1a24f44ca", }, RealESRGANModel.RealESRGAN_x4plus_anime_6B: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", "scale": 4, "model": lambda: RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4, ), "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba", }, } if name not in REAL_ESRGAN_MODELS: raise ValueError(f"Unknown RealESRGAN model name: {name}") model_info = REAL_ESRGAN_MODELS[name] model_path = download_model(model_info["url"], model_info["model_md5"]) logger.info(f"RealESRGAN model path: {model_path}") self.model = RealESRGANer( scale=model_info["scale"], model_path=model_path, model=model_info["model"](), half=True if "cuda" in str(self.device) and not self.no_half else False, tile=512, tile_pad=10, pre_pad=10, device=self.device, ) def switch_model(self, new_model_name: str): if self.model_name == new_model_name: return self._init_model(new_model_name) self.model_name = new_model_name def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") result = self.forward(bgr_np_img, req.scale) logger.info(f"RealESRGAN output shape: {result.shape}") return result @torch.inference_mode() def forward(self, bgr_np_img, scale: float): # 输出是 BGR upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] return upsampled