remove realesrgan dep

Qing 2024-08-12 10:34:54 +08:00
parent 2f833029aa
commit ffdf5e06e1
7 changed files with 805 additions and 9 deletions

View File

@@ -8,5 +8,4 @@ def install(package):

 def install_plugins_package():
     install("rembg")
-    install("realesrgan")
     install("gfpgan")

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2018-2022 BasicSR Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,22 @@
"""
Adapted from https://github.com/XPixelGroup/BasicSR
License: Apache-2.0
As of Feb 2024, `basicsr` appears to be unmaintained. It imports a function from `torchvision` that is removed in
`torchvision` 0.17. Here is the deprecation warning:
UserWarning: The torchvision.transforms.functional_tensor module is deprecated in 0.15 and will be **removed in
0.17**. Please don't rely on it. You probably just need to use APIs in torchvision.transforms.functional or in
torchvision.transforms.v2.functional.
As a result, a dependency on `basicsr` means we cannot keep our `torchvision` dependency up to date.
Because we only rely on a single class `RRDBNet` from `basicsr`, we've copied the relevant code here and removed the
dependency on `basicsr`.
The code is almost unchanged; only a few type annotations have been added. The license is also copied.
Copied from InvokeAI.
"""
from .rrdbnet_arch import RRDBNet
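
For illustration only (not part of the commit): with the vendored copy in place, downstream code can drop the external `basicsr` package with a one-line import change. The absolute module path below is an assumption inferred from the plugin's relative `from .basicsr import RRDBNet` import.

# Hypothetical sketch: swap the basicsr import for the vendored copy.
# Old (requires the unmaintained basicsr package):
#   from basicsr.archs.rrdbnet_arch import RRDBNet
# New (assumed path, based on the relative import used by the plugin):
from iopaint.plugins.basicsr import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32)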

View File

@@ -0,0 +1,80 @@
from typing import Type, List, Union
import torch
from torch import nn as nn
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
@torch.no_grad()
def default_init_weights(
module_list: Union[List[nn.Module], nn.Module],
scale: float = 1,
bias_fill: float = 0,
**kwargs,
) -> None:
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(
basic_block: Type[nn.Module], num_basic_block: int, **kwarg
) -> nn.Sequential:
"""Make layers by stacking the same blocks.
Args:
basic_block (Type[nn.Module]): nn.Module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
# TODO: may write a cpp file
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
"""Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
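
As a quick illustration of the shape contract described in the docstring (a sketch, not part of the commit): a scale of 2 halves each spatial dimension and multiplies the channel count by 4. The import path for the helper is assumed, since the diff does not show file names.

# Illustrative shape check for pixel_unshuffle; module path is an assumption.
import torch
from iopaint.plugins.basicsr.arch_util import pixel_unshuffle

x = torch.randn(1, 3, 8, 8)        # (b, c, hh, hw)
y = pixel_unshuffle(x, scale=2)     # spatial dims / 2, channels * scale**2
assert y.shape == (1, 12, 4, 4)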

View File

@@ -0,0 +1,133 @@
import torch
from torch import nn as nn
from torch.nn import functional as F
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat: int = 64, num_grow_ch: int = 32) -> None:
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights(
[self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Empirically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat: int, num_grow_ch: int = 32) -> None:
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1 and scale 2 in RRDBNet.
We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
num_block (int): Block number in the trunk network. Defaults: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(
self,
num_in_ch: int,
num_out_ch: int,
scale: int = 4,
num_feat: int = 64,
num_block: int = 23,
num_grow_ch: int = 32,
) -> None:
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(
RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch
)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
)
feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
)
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
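
A minimal usage sketch of the vendored RRDBNet (not part of the commit; the import path is assumed and the weights are random). With scale=4 the forward pass upsamples by 4; for scale=2 or scale=1 the input is first pixel-unshuffled, so the two nearest-neighbor upsampling stages still yield the requested overall scale.

# Illustrative only; real use loads a pretrained checkpoint into the network.
import torch
from iopaint.plugins.basicsr import RRDBNet  # assumed path

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32)
net.eval()
with torch.no_grad():
    sr = net(torch.randn(1, 3, 64, 64))
print(sr.shape)  # torch.Size([1, 3, 256, 256])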

View File

@@ -1,6 +1,10 @@
+import math
+
 import cv2
 import numpy as np
 import torch
+from torch import nn
+import torch.nn.functional as F
 from loguru import logger

 from iopaint.helper import download_model
@@ -8,6 +12,369 @@ from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest, RealESRGANModel
class RealESRGANer:
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. Default: None.
tile (int): Because very large images can cause GPU out-of-memory errors, this tile option first crops the
input image into tiles, processes each tile separately, and finally merges them back into one image.
0 means tiling is not used. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (bool): Whether to use half precision during inference. Default: False.
"""
def __init__(
self,
scale,
model_path,
dni_weight=None,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None,
):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
if gpu_id:
self.device = (
torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
else:
self.device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
if isinstance(model_path, list):
# dni
assert len(model_path) == len(
dni_weight
), "model_path and dni_weight should have the save length."
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
else:
# if the model_path starts with https, it will first download models to the folder: weights
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"
else:
keyname = "params"
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
"""Deep network interpolation.
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
"""
net_a = torch.load(net_a, map_location=torch.device(loc))
net_b = torch.load(net_b, map_location=torch.device(loc))
for k, v_a in net_a[key].items():
net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
return net_a
def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if h % self.mod_scale != 0:
self.mod_pad_h = self.mod_scale - h % self.mod_scale
if w % self.mod_scale != 0:
self.mod_pad_w = self.mod_scale - w % self.mod_scale
self.img = F.pad(
self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
)
def process(self):
# model inference
self.output = self.model(self.img)
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[
:,
:,
input_start_y_pad:input_end_y_pad,
input_start_x_pad:input_end_x_pad,
]
# upscale tile
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except RuntimeError as error:
print("Error", error)
print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[
:, :, output_start_y:output_end_y, output_start_x:output_end_x
] = output_tile[
:,
:,
output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile,
]
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.mod_pad_h * self.scale,
0 : w - self.mod_pad_w * self.scale,
]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.pre_pad * self.scale,
0 : w - self.pre_pad * self.scale,
]
return self.output
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
h_input, w_input = img.shape[0:2]
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 256: # 16-bit image
max_range = 65535
print("\tInput is a 16-bit image")
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = "L"
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = "RGBA"
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == "realesrgan":
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = "RGB"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img = self.post_process()
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == "L":
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == "RGBA":
if alpha_upsampler == "realesrgan":
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = (
output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
)
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(
alpha,
(w * self.scale, h * self.scale),
interpolation=cv2.INTER_LINEAR,
)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output,
(
int(w_input * outscale),
int(h_input * outscale),
),
interpolation=cv2.INTER_LANCZOS4,
)
return output, img_mode
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
It is a compact network structure that performs upsampling in the last layer; no convolution is
conducted on the HR feature space.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_conv (int): Number of convolution layers in the body network. Default: 16.
upscale (int): Upsampling factor. Default: 4.
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
"""
def __init__(
self,
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=16,
upscale=4,
act_type="prelu",
):
super(SRVGGNetCompact, self).__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
self.num_conv = num_conv
self.upscale = upscale
self.act_type = act_type
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
# the first activation
if act_type == "relu":
activation = nn.ReLU(inplace=True)
elif act_type == "prelu":
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == "leakyrelu":
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the body structure
for _ in range(num_conv):
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
# activation
if act_type == "relu":
activation = nn.ReLU(inplace=True)
elif act_type == "prelu":
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == "leakyrelu":
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the last conv
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
# upsample
self.upsampler = nn.PixelShuffle(upscale)
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
out += base
return out
class RealESRGANUpscaler(BasePlugin):
    name = "RealESRGAN"
    support_gen_image = True
@@ -20,9 +387,7 @@ class RealESRGANUpscaler(BasePlugin):
         self._init_model(name)

     def _init_model(self, name):
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan import RealESRGANer
-        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+        from .basicsr import RRDBNet

         REAL_ESRGAN_MODELS = {
             RealESRGANModel.realesr_general_x4v3: {
@@ -103,7 +468,4 @@ class RealESRGANUpscaler(BasePlugin):
         return upsampled

     def check_dep(self):
-        try:
-            import realesrgan
-        except ImportError:
-            return "RealESRGAN is not installed, please install it first. pip install realesrgan"
+        pass

View File

@@ -30,7 +30,6 @@ _CANDIDATES = [
     "accelerate",
     "iopaint",
     "rembg",
-    "realesrgan",
     "gfpgan",
 ]
 # Check once at runtime