remove realesrgan dep
This commit is contained in:
parent
2f833029aa
commit
ffdf5e06e1
@ -8,5 +8,4 @@ def install(package):
|
|||||||
|
|
||||||
def install_plugins_package():
|
def install_plugins_package():
|
||||||
install("rembg")
|
install("rembg")
|
||||||
install("realesrgan")
|
|
||||||
install("gfpgan")
|
install("gfpgan")
|
||||||
|
201
iopaint/plugins/basicsr/LICENSE
Normal file
201
iopaint/plugins/basicsr/LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2018-2022 BasicSR Authors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
22
iopaint/plugins/basicsr/__init__.py
Normal file
22
iopaint/plugins/basicsr/__init__.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
Adapted from https://github.com/XPixelGroup/BasicSR
|
||||||
|
License: Apache-2.0
|
||||||
|
|
||||||
|
As of Feb 2024, `basicsr` appears to be unmaintained. It imports a function from `torchvision` that is removed in
|
||||||
|
`torchvision` 0.17. Here is the deprecation warning:
|
||||||
|
|
||||||
|
UserWarning: The torchvision.transforms.functional_tensor module is deprecated in 0.15 and will be **removed in
|
||||||
|
0.17**. Please don't rely on it. You probably just need to use APIs in torchvision.transforms.functional or in
|
||||||
|
torchvision.transforms.v2.functional.
|
||||||
|
|
||||||
|
As a result, a dependency on `basicsr` means we cannot keep our `torchvision` dependency up to date.
|
||||||
|
|
||||||
|
Because we only rely on a single class `RRDBNet` from `basicsr`, we've copied the relevant code here and removed the
|
||||||
|
dependency on `basicsr`.
|
||||||
|
|
||||||
|
The code is almost unchanged, only a few type annotations have been added. The license is also copied.
|
||||||
|
|
||||||
|
Copied from InvokeAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .rrdbnet_arch import RRDBNet
|
80
iopaint/plugins/basicsr/arch_util.py
Normal file
80
iopaint/plugins/basicsr/arch_util.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from typing import Type, List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import init as init
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def default_init_weights(
    module_list: Union[List[nn.Module], nn.Module],
    scale: float = 1,
    bias_fill: float = 0,
    **kwargs,
) -> None:
    """Initialize network weights in place.

    Every ``nn.Conv2d`` and ``nn.Linear`` found in the given module(s) gets
    Kaiming-normal weights multiplied by ``scale``; batch-norm layers get a
    weight of 1. All existing biases are filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
            A single module or any iterable of modules (list, tuple, ...)
            is accepted.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments forwarded to ``init.kaiming_normal_``.
    """
    # Wrap only real modules; lists, tuples and other iterables are walked
    # as-is (the previous ``isinstance(..., list)`` check broke on tuples).
    if isinstance(module_list, nn.Module):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Norm layers built with affine=False have no weight/bias.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(
    basic_block: Type[nn.Module], num_basic_block: int, **kwarg
) -> nn.Sequential:
    """Stack ``num_basic_block`` fresh instances of one block type.

    Args:
        basic_block (Type[nn.Module]): nn.Module class for basic block.
        num_basic_block (int): number of blocks.
        kwarg: Keyword arguments passed to every ``basic_block(...)`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    stacked = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*stacked)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: may write a cpp file
|
||||||
|
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """Fold ``scale`` x ``scale`` spatial blocks into the channel dimension.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must divide both hh and hw.

    Returns:
        Tensor: the pixel unshuffled feature, shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the block)
    blocks = x.view(batch, channels, out_h, scale, out_w, scale)
    # Bring both in-block offsets next to the channel axis, then merge them
    rearranged = blocks.permute(0, 1, 3, 5, 2, 4)
    return rearranged.reshape(batch, channels * (scale**2), out_h, out_w)
|
133
iopaint/plugins/basicsr/rrdbnet_arch.py
Normal file
133
iopaint/plugins/basicsr/rrdbnet_arch.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the building unit of the RRDB block in ESRGAN.

    Five 3x3 convolutions are densely connected: each one receives the block
    input concatenated with the outputs of all earlier convolutions.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat: int = 64, num_grow_ch: int = 32) -> None:
        super().__init__()
        kernel, stride, padding = 3, 1, 1
        # Input width grows by num_grow_ch with every earlier conv output
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, kernel, stride, padding)
        self.conv2 = nn.Conv2d(
            num_feat + num_grow_ch, num_grow_ch, kernel, stride, padding
        )
        self.conv3 = nn.Conv2d(
            num_feat + 2 * num_grow_ch, num_grow_ch, kernel, stride, padding
        )
        self.conv4 = nn.Conv2d(
            num_feat + 3 * num_grow_ch, num_grow_ch, kernel, stride, padding
        )
        # The last conv maps back to num_feat so the residual can be added
        self.conv5 = nn.Conv2d(
            num_feat + 4 * num_grow_ch, num_feat, kernel, stride, padding
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (weights scaled by 0.1)
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(convs, 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the dense block and add the scaled result to ``x``."""
        feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return residual * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in RRDB-Net in ESRGAN.

    Chains three ``ResidualDenseBlock`` modules and wraps them in an outer
    scaled residual connection.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat: int, num_grow_ch: int = 32) -> None:
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three dense blocks in order and add the scaled result to ``x``."""
        inner = self.rdb3(self.rdb2(self.rdb1(x)))
        # Empirically, we use 0.2 to scale the residual for better performance
        return inner * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle)
    to reduce the spatial size and enlarge the channel size before feeding
    inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall scale. 2 and 1 pixel-unshuffle the input by 2
            and 4 respectively; any other value feeds the input unchanged.
            The trunk always upsamples by 2 twice. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(
        self,
        num_in_ch: int,
        num_out_ch: int,
        scale: int = 4,
        num_feat: int = 64,
        num_block: int = 23,
        num_grow_ch: int = 32,
    ) -> None:
        super().__init__()
        self.scale = scale
        # Pixel-unshuffle by r multiplies the channel count by r**2
        channel_multiplier = {2: 4, 1: 16}.get(scale, 1)
        first_in_ch = num_in_ch * channel_multiplier
        self.conv_first = nn.Conv2d(first_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Optionally unshuffle ``x``, run the RRDB trunk, then upsample twice by 2."""
        unshuffle_factor = {2: 2, 1: 4}.get(self.scale)
        if unshuffle_factor is None:
            feat = x
        else:
            feat = pixel_unshuffle(x, scale=unshuffle_factor)

        shallow = self.conv_first(feat)
        trunk = self.conv_body(self.body(shallow))
        feat = shallow + trunk
        # upsample: two nearest-neighbour x2 steps, each followed by conv + LeakyReLU
        for up_conv in (self.conv_up1, self.conv_up2):
            enlarged = F.interpolate(feat, scale_factor=2, mode="nearest")
            feat = self.lrelu(up_conv(enlarged))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
|
@ -1,6 +1,10 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from iopaint.helper import download_model
|
from iopaint.helper import download_model
|
||||||
@ -8,6 +12,369 @@ from iopaint.plugins.base_plugin import BasePlugin
|
|||||||
from iopaint.schema import RunPluginRequest, RealESRGANModel
|
from iopaint.schema import RunPluginRequest, RealESRGANModel
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANer:
|
||||||
|
"""A helper class for upsampling images with RealESRGAN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||||
|
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||||
|
model (nn.Module): The defined network. Default: None.
|
||||||
|
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||||
|
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||||
|
0 denotes for do not use tile. Default: 0.
|
||||||
|
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||||
|
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||||
|
half (bool): Whether to use half precision during inference. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scale,
|
||||||
|
model_path,
|
||||||
|
dni_weight=None,
|
||||||
|
model=None,
|
||||||
|
tile=0,
|
||||||
|
tile_pad=10,
|
||||||
|
pre_pad=10,
|
||||||
|
half=False,
|
||||||
|
device=None,
|
||||||
|
gpu_id=None,
|
||||||
|
):
|
||||||
|
self.scale = scale
|
||||||
|
self.tile_size = tile
|
||||||
|
self.tile_pad = tile_pad
|
||||||
|
self.pre_pad = pre_pad
|
||||||
|
self.mod_scale = None
|
||||||
|
self.half = half
|
||||||
|
|
||||||
|
# initialize model
|
||||||
|
if gpu_id:
|
||||||
|
self.device = (
|
||||||
|
torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||||
|
if device is None
|
||||||
|
else device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.device = (
|
||||||
|
torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
if device is None
|
||||||
|
else device
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(model_path, list):
|
||||||
|
# dni
|
||||||
|
assert len(model_path) == len(
|
||||||
|
dni_weight
|
||||||
|
), "model_path and dni_weight should have the save length."
|
||||||
|
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
||||||
|
else:
|
||||||
|
# if the model_path starts with https, it will first download models to the folder: weights
|
||||||
|
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||||
|
|
||||||
|
# prefer to use params_ema
|
||||||
|
if "params_ema" in loadnet:
|
||||||
|
keyname = "params_ema"
|
||||||
|
else:
|
||||||
|
keyname = "params"
|
||||||
|
model.load_state_dict(loadnet[keyname], strict=True)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
self.model = model.to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.model = self.model.half()
|
||||||
|
|
||||||
|
def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
    """Blend two checkpoints by deep network interpolation.

    ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``

    Both files are read with ``torch.load``. Every entry under ``key`` of the
    first checkpoint is replaced by
    ``dni_weight[0] * a + dni_weight[1] * b`` and the first checkpoint is
    returned.
    """
    ckpt_a = torch.load(net_a, map_location=torch.device(loc))
    ckpt_b = torch.load(net_b, map_location=torch.device(loc))
    weight_a = dni_weight[0]
    weight_b = dni_weight[1]
    params_a = ckpt_a[key]
    for name, tensor_a in params_a.items():
        # Overwriting existing keys keeps the dict size fixed during iteration
        params_a[name] = weight_a * tensor_a + weight_b * ckpt_b[key][name]
    return ckpt_a
|
||||||
|
|
||||||
|
def pre_process(self, img):
|
||||||
|
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
|
||||||
|
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||||
|
self.img = img.unsqueeze(0).to(self.device)
|
||||||
|
if self.half:
|
||||||
|
self.img = self.img.half()
|
||||||
|
|
||||||
|
# pre_pad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
|
||||||
|
# mod pad for divisible borders
|
||||||
|
if self.scale == 2:
|
||||||
|
self.mod_scale = 2
|
||||||
|
elif self.scale == 1:
|
||||||
|
self.mod_scale = 4
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||||
|
_, _, h, w = self.img.size()
|
||||||
|
if h % self.mod_scale != 0:
|
||||||
|
self.mod_pad_h = self.mod_scale - h % self.mod_scale
|
||||||
|
if w % self.mod_scale != 0:
|
||||||
|
self.mod_pad_w = self.mod_scale - w % self.mod_scale
|
||||||
|
self.img = F.pad(
|
||||||
|
self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self):
    """Run the model once on the whole of ``self.img`` and store the result in ``self.output``."""
    # model inference
    self.output = self.model(self.img)
|
||||||
|
|
||||||
|
def tile_process(self):
|
||||||
|
"""It will first crop input images to tiles, and then process each tile.
|
||||||
|
Finally, all the processed tiles are merged into one images.
|
||||||
|
|
||||||
|
Modified from: https://github.com/ata4/esrgan-launcher
|
||||||
|
"""
|
||||||
|
batch, channel, height, width = self.img.shape
|
||||||
|
output_height = height * self.scale
|
||||||
|
output_width = width * self.scale
|
||||||
|
output_shape = (batch, channel, output_height, output_width)
|
||||||
|
|
||||||
|
# start with black image
|
||||||
|
self.output = self.img.new_zeros(output_shape)
|
||||||
|
tiles_x = math.ceil(width / self.tile_size)
|
||||||
|
tiles_y = math.ceil(height / self.tile_size)
|
||||||
|
|
||||||
|
# loop over all tiles
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
# extract tile from input image
|
||||||
|
ofs_x = x * self.tile_size
|
||||||
|
ofs_y = y * self.tile_size
|
||||||
|
# input tile area on total image
|
||||||
|
input_start_x = ofs_x
|
||||||
|
input_end_x = min(ofs_x + self.tile_size, width)
|
||||||
|
input_start_y = ofs_y
|
||||||
|
input_end_y = min(ofs_y + self.tile_size, height)
|
||||||
|
|
||||||
|
# input tile area on total image with padding
|
||||||
|
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||||
|
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||||
|
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||||
|
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||||
|
|
||||||
|
# input tile dimensions
|
||||||
|
input_tile_width = input_end_x - input_start_x
|
||||||
|
input_tile_height = input_end_y - input_start_y
|
||||||
|
tile_idx = y * tiles_x + x + 1
|
||||||
|
input_tile = self.img[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
input_start_y_pad:input_end_y_pad,
|
||||||
|
input_start_x_pad:input_end_x_pad,
|
||||||
|
]
|
||||||
|
|
||||||
|
# upscale tile
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output_tile = self.model(input_tile)
|
||||||
|
except RuntimeError as error:
|
||||||
|
print("Error", error)
|
||||||
|
print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
|
||||||
|
|
||||||
|
# output tile area on total image
|
||||||
|
output_start_x = input_start_x * self.scale
|
||||||
|
output_end_x = input_end_x * self.scale
|
||||||
|
output_start_y = input_start_y * self.scale
|
||||||
|
output_end_y = input_end_y * self.scale
|
||||||
|
|
||||||
|
# output tile area without padding
|
||||||
|
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||||
|
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||||
|
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||||
|
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||||
|
|
||||||
|
# put tile into output image
|
||||||
|
self.output[
|
||||||
|
:, :, output_start_y:output_end_y, output_start_x:output_end_x
|
||||||
|
] = output_tile[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
output_start_y_tile:output_end_y_tile,
|
||||||
|
output_start_x_tile:output_end_x_tile,
|
||||||
|
]
|
||||||
|
|
||||||
|
def post_process(self):
|
||||||
|
# remove extra pad
|
||||||
|
if self.mod_scale is not None:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
0 : h - self.mod_pad_h * self.scale,
|
||||||
|
0 : w - self.mod_pad_w * self.scale,
|
||||||
|
]
|
||||||
|
# remove prepad
|
||||||
|
if self.pre_pad != 0:
|
||||||
|
_, _, h, w = self.output.size()
|
||||||
|
self.output = self.output[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
0 : h - self.pre_pad * self.scale,
|
||||||
|
0 : w - self.pre_pad * self.scale,
|
||||||
|
]
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
    """Upscale a numpy image with the wrapped model.

    Colour channels are swapped with ``COLOR_BGR2RGB`` before inference and
    reversed again afterwards, so colour input is presumably in OpenCV's BGR
    order (confirm against callers); the returned image uses the same order
    as the input.

    Args:
        img: Input image array. A 2-D array is handled as grayscale, an
            array whose third axis has length 4 as colour plus alpha, and
            anything else as 3-channel colour. Inputs whose maximum value
            exceeds 256 are treated as 16-bit, otherwise as 8-bit.
        outscale: Optional final scale relative to the input size. When
            given and different from ``self.scale``, the model output is
            resized with Lanczos interpolation to reach it.
        alpha_upsampler: ``"realesrgan"`` upscales the alpha channel with
            the model as well; any other value resizes it with bilinear
            interpolation.

    Returns:
        tuple: ``(output, img_mode)`` where ``output`` is a ``uint16``
        (16-bit input) or ``uint8`` array and ``img_mode`` is ``"L"``,
        ``"RGB"`` or ``"RGBA"``.
    """
    src_h, src_w = img.shape[0:2]

    def _upscale_to_bgr(rgb):
        # One full pass through the model: pad, infer (tiled or whole),
        # un-pad, then convert the tensor to an HxWx3 float array with the
        # channel order reversed and values clamped to [0, 1].
        self.pre_process(rgb)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        chw = self.post_process().data.squeeze().float().cpu().clamp_(0, 1).numpy()
        return np.transpose(chw[[2, 1, 0], :, :], (1, 2, 0))

    # img: numpy -> float32 in [0, 1]
    pixels = img.astype(np.float32)
    if np.max(pixels) > 256:  # 16-bit image
        max_range = 65535
        print("\tInput is a 16-bit image")
    else:
        max_range = 255
    pixels = pixels / max_range

    # Bring every input layout to a 3-channel RGB array for the model.
    if len(pixels.shape) == 2:  # gray image
        img_mode = "L"
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    elif pixels.shape[2] == 4:  # RGBA image with alpha channel
        img_mode = "RGBA"
        alpha, pixels = pixels[:, :, 3], pixels[:, :, 0:3]
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        if alpha_upsampler == "realesrgan":
            # The model expects 3 channels, so replicate alpha across them.
            alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
    else:
        img_mode = "RGB"
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

    # ------------------- process image (without the alpha channel) ------------------- #
    output_img = _upscale_to_bgr(pixels)

    if img_mode == "L":
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
    elif img_mode == "RGBA":
        # ------------------- process the alpha channel ------------------- #
        if alpha_upsampler == "realesrgan":
            output_alpha = cv2.cvtColor(_upscale_to_bgr(alpha), cv2.COLOR_BGR2GRAY)
        else:  # use the cv2 resize for alpha channel
            alpha_h, alpha_w = alpha.shape[0:2]
            output_alpha = cv2.resize(
                alpha,
                (alpha_w * self.scale, alpha_h * self.scale),
                interpolation=cv2.INTER_LINEAR,
            )

        # merge the alpha channel
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
        output_img[:, :, 3] = output_alpha

    # ------------------------------ return ------------------------------ #
    if max_range == 65535:  # 16-bit image
        peak, out_dtype = 65535.0, np.uint16
    else:
        peak, out_dtype = 255.0, np.uint8
    output = (output_img * peak).round().astype(out_dtype)

    # No extra resize needed when no outscale was requested or it already
    # matches the model's native scale.
    if outscale is None or outscale == float(self.scale):
        return output, img_mode

    target_size = (int(src_w * outscale), int(src_h * outscale))
    output = cv2.resize(output, target_size, interpolation=cv2.INTER_LANCZOS4)
    return output, img_mode
|
||||||
|
|
||||||
|
|
||||||
|
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last
    layer (pixel shuffle) and no convolution is conducted on the HR feature
    space. The network output is added to a nearest-neighbour upsampled copy
    of the input, so the layers learn a residual; that in-place add requires
    the input channels to be broadcastable to ``num_out_ch`` (in practice
    ``num_in_ch == num_out_ch``).

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.

    Raises:
        ValueError: If ``act_type`` is not one of the supported options.
    """

    def __init__(
        self,
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=16,
        upscale=4,
        act_type="prelu",
    ):
        super().__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        # Layer order (conv, act, [conv, act] * num_conv, conv) determines the
        # ``body.<index>`` keys in the state dict, so it must not change.
        self.body = nn.ModuleList()
        # the first conv and its activation
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        self.body.append(self._make_activation(act_type, num_feat))

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation(act_type, num_feat))

        # the last conv: produces upscale**2 sub-pixel channels per output channel
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    @staticmethod
    def _make_activation(act_type, num_feat):
        """Return a fresh activation module for ``act_type``.

        A new instance is created on every call because PReLU carries its
        own learnable parameters per layer.
        """
        if act_type == "relu":
            return nn.ReLU(inplace=True)
        if act_type == "prelu":
            return nn.PReLU(num_parameters=num_feat)
        if act_type == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.1, inplace=True)
        raise ValueError(
            f"Unsupported act_type {act_type!r}; "
            "expected one of 'relu', 'prelu', 'leakyrelu'."
        )

    def forward(self, x):
        """Upscale ``x`` by ``self.upscale`` and return the result."""
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
        out += base
        return out
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANUpscaler(BasePlugin):
|
class RealESRGANUpscaler(BasePlugin):
|
||||||
name = "RealESRGAN"
|
name = "RealESRGAN"
|
||||||
support_gen_image = True
|
support_gen_image = True
|
||||||
@ -20,9 +387,7 @@ class RealESRGANUpscaler(BasePlugin):
|
|||||||
self._init_model(name)
|
self._init_model(name)
|
||||||
|
|
||||||
def _init_model(self, name):
|
def _init_model(self, name):
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from .basicsr import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
REAL_ESRGAN_MODELS = {
|
REAL_ESRGAN_MODELS = {
|
||||||
RealESRGANModel.realesr_general_x4v3: {
|
RealESRGANModel.realesr_general_x4v3: {
|
||||||
@ -103,7 +468,4 @@ class RealESRGANUpscaler(BasePlugin):
|
|||||||
return upsampled
|
return upsampled
|
||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
try:
|
pass
|
||||||
import realesrgan
|
|
||||||
except ImportError:
|
|
||||||
return "RealESRGAN is not installed, please install it first. pip install realesrgan"
|
|
||||||
|
@ -30,7 +30,6 @@ _CANDIDATES = [
|
|||||||
"accelerate",
|
"accelerate",
|
||||||
"iopaint",
|
"iopaint",
|
||||||
"rembg",
|
"rembg",
|
||||||
"realesrgan",
|
|
||||||
"gfpgan",
|
"gfpgan",
|
||||||
]
|
]
|
||||||
# Check once at runtime
|
# Check once at runtime
|
||||||
|
Loading…
Reference in New Issue
Block a user