From ffdf5e06e16ddafdb3d964f83ad1448d99de915d Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 12 Aug 2024 10:34:54 +0800 Subject: [PATCH] remove realesrgan dep --- iopaint/installer.py | 1 - iopaint/plugins/basicsr/LICENSE | 201 +++++++++++++ iopaint/plugins/basicsr/__init__.py | 22 ++ iopaint/plugins/basicsr/arch_util.py | 80 +++++ iopaint/plugins/basicsr/rrdbnet_arch.py | 133 +++++++++ iopaint/plugins/realesrgan.py | 376 +++++++++++++++++++++++- iopaint/runtime.py | 1 - 7 files changed, 805 insertions(+), 9 deletions(-) create mode 100644 iopaint/plugins/basicsr/LICENSE create mode 100644 iopaint/plugins/basicsr/__init__.py create mode 100644 iopaint/plugins/basicsr/arch_util.py create mode 100644 iopaint/plugins/basicsr/rrdbnet_arch.py diff --git a/iopaint/installer.py b/iopaint/installer.py index f255e33..71894e0 100644 --- a/iopaint/installer.py +++ b/iopaint/installer.py @@ -8,5 +8,4 @@ def install(package): def install_plugins_package(): install("rembg") - install("realesrgan") install("gfpgan") diff --git a/iopaint/plugins/basicsr/LICENSE b/iopaint/plugins/basicsr/LICENSE new file mode 100644 index 0000000..1c9b5b8 --- /dev/null +++ b/iopaint/plugins/basicsr/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018-2022 BasicSR Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/iopaint/plugins/basicsr/__init__.py b/iopaint/plugins/basicsr/__init__.py new file mode 100644 index 0000000..6bd8efd --- /dev/null +++ b/iopaint/plugins/basicsr/__init__.py @@ -0,0 +1,22 @@ +""" +Adapted from https://github.com/XPixelGroup/BasicSR +License: Apache-2.0 + +As of Feb 2024, `basicsr` appears to be unmaintained. It imports a function from `torchvision` that is removed in +`torchvision` 0.17. Here is the deprecation warning: + + UserWarning: The torchvision.transforms.functional_tensor module is deprecated in 0.15 and will be **removed in + 0.17**. Please don't rely on it. You probably just need to use APIs in torchvision.transforms.functional or in + torchvision.transforms.v2.functional. + +As a result, a dependency on `basicsr` means we cannot keep our `torchvision` dependency up to date. + +Because we only rely on a single class `RRDBNet` from `basicsr`, we've copied the relevant code here and removed the +dependency on `basicsr`. + +The code is almost unchanged, only a few type annotations have been added. The license is also copied. + +Copy From InvokeAI +""" + +from .rrdbnet_arch import RRDBNet diff --git a/iopaint/plugins/basicsr/arch_util.py b/iopaint/plugins/basicsr/arch_util.py new file mode 100644 index 0000000..befe76a --- /dev/null +++ b/iopaint/plugins/basicsr/arch_util.py @@ -0,0 +1,80 @@ +from typing import Type, List, Union + +import torch +from torch import nn as nn +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights( + module_list: Union[List[nn.Module], nn.Module], + scale: float = 1, + bias_fill: float = 0, + **kwargs, +) -> None: + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer( + basic_block: Type[nn.Module], num_basic_block: int, **kwarg +) -> nn.Sequential: + """Make layers by stacking the same blocks. + + Args: + basic_block (Type[nn.Module]): nn.Module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +# TODO: may write a cpp file +def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor: + """Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) diff --git a/iopaint/plugins/basicsr/rrdbnet_arch.py b/iopaint/plugins/basicsr/rrdbnet_arch.py new file mode 100644 index 0000000..31c08eb --- /dev/null +++ b/iopaint/plugins/basicsr/rrdbnet_arch.py @@ -0,0 +1,133 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat: int = 64, num_grow_ch: int = 32) -> None: + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights( + [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat: int, num_grow_ch: int = 32) -> None: + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__( + self, + num_in_ch: int, + num_out_ch: int, + scale: int = 4, + num_feat: int = 64, + num_block: int = 23, + num_grow_ch: int = 32, + ) -> None: + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer( + RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch + ) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) + ) + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) + ) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/iopaint/plugins/realesrgan.py b/iopaint/plugins/realesrgan.py index 5275700..dbcbb80 100644 --- a/iopaint/plugins/realesrgan.py +++ b/iopaint/plugins/realesrgan.py @@ -1,6 +1,10 @@ +import math + import cv2 import numpy as np import torch +from torch import nn +import torch.nn.functional as F from loguru import logger from iopaint.helper import download_model @@ -8,6 +12,369 @@ from iopaint.plugins.base_plugin import BasePlugin from iopaint.schema import RunPluginRequest, RealESRGANModel +class RealESRGANer: + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__( + self, + scale, + model_path, + dni_weight=None, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None, + ): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = ( + torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + else: + self.device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + + if isinstance(model_path, list): + # dni + assert len(model_path) == len( + dni_weight + ), "model_path and dni_weight should have the save length." + loadnet = self.dni(model_path[0], model_path[1], dni_weight) + else: + # if the model_path starts with https, it will first download models to the folder: weights + loadnet = torch.load(model_path, map_location=torch.device("cpu")) + + # prefer to use params_ema + if "params_ema" in loadnet: + keyname = "params_ema" + else: + keyname = "params" + model.load_state_dict(loadnet[keyname], strict=True) + + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"): + """Deep network interpolation. + + ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` + """ + net_a = torch.load(net_a, map_location=torch.device(loc)) + net_b = torch.load(net_b, map_location=torch.device(loc)) + for k, v_a in net_a[key].items(): + net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] + return net_a + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible""" + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect") + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if h % self.mod_scale != 0: + self.mod_pad_h = self.mod_scale - h % self.mod_scale + if w % self.mod_scale != 0: + self.mod_pad_w = self.mod_scale - w % self.mod_scale + self.img = F.pad( + self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect" + ) + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[ + :, + :, + input_start_y_pad:input_end_y_pad, + input_start_x_pad:input_end_x_pad, + ] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print("Error", error) + print(f"\tTile {tile_idx}/{tiles_x * tiles_y}") + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[ + :, :, output_start_y:output_end_y, output_start_x:output_end_x + ] = output_tile[ + :, + :, + output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile, + ] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[ + :, + :, + 0 : h - self.mod_pad_h * self.scale, + 0 : w - self.mod_pad_w * self.scale, + ] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[ + :, + :, + 0 : h - self.pre_pad * self.scale, + 0 : w - self.pre_pad * self.scale, + ] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print("\tInput is a 16-bit image") + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = "L" + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = "RGBA" + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == "realesrgan": + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = "RGB" + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == "L": + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == "RGBA": + if alpha_upsampler == "realesrgan": + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = ( + output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + ) + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize( + alpha, + (w * self.scale, h * self.scale), + interpolation=cv2.INTER_LINEAR, + ) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, + ( + int(w_input * outscale), + int(h_input * outscale), + ), + interpolation=cv2.INTER_LANCZOS4, + ) + + return output, img_mode + + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. + """ + + def __init__( + self, + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=16, + upscale=4, + act_type="prelu", + ): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode="nearest") + out += base + return out + + class RealESRGANUpscaler(BasePlugin): name = "RealESRGAN" support_gen_image = True @@ -20,9 +387,7 @@ class RealESRGANUpscaler(BasePlugin): self._init_model(name) def _init_model(self, name): - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer - from realesrgan.archs.srvgg_arch import SRVGGNetCompact + from .basicsr import RRDBNet REAL_ESRGAN_MODELS = { RealESRGANModel.realesr_general_x4v3: { @@ -103,7 +468,4 @@ class RealESRGANUpscaler(BasePlugin): return upsampled def check_dep(self): - try: - import realesrgan - except ImportError: - return "RealESRGAN is not installed, please install it first. pip install realesrgan" + pass diff --git a/iopaint/runtime.py b/iopaint/runtime.py index 7199f83..c167fae 100644 --- a/iopaint/runtime.py +++ b/iopaint/runtime.py @@ -30,7 +30,6 @@ _CANDIDATES = [ "accelerate", "iopaint", "rembg", - "realesrgan", "gfpgan", ] # Check once at runtime