463 lines
14 KiB
Python
463 lines
14 KiB
Python
import cv2
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from lama_cleaner.helper import load_model
|
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
|
from lama_cleaner.schema import RunPluginRequest
|
|
|
|
|
|
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The padding always equals the dilation rate, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate of the convolution (also used as padding).
        stride: Stride of the convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        # The attribute names below become the state-dict keys, so they have
        # to stay as they are for previously saved weights to load.
        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            stride=stride,
            padding=1 * dirate,
            dilation=1 * dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
|
|
|
|
|
|
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
|
def _upsample_like(src, tar):
|
|
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
|
|
|
return src
|
|
|
|
|
|
### RSU-7 ###
|
|
class RSU7(nn.Module):
    """RSU-7 block: a small encoder/decoder with a residual connection.

    The encoder applies six ``REBNCONV`` layers separated by five 2x
    max-pooling steps, followed by a dilated (rate 2) bottom layer. The
    decoder walks back up, concatenating each upsampled result with the
    encoder feature of the same resolution. The output of the input
    convolution is added to the decoder output.

    Args:
        in_ch: Number of channels of the input tensor.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of channels of the returned tensor.
        img_size: Accepted for signature compatibility; not used by this
            block.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Attribute names become state-dict keys; keep them stable.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, with ``out_ch`` channels."""
        hxin = self.rebnconvin(x)

        # Encoder: each level works on the pooled output of the previous one.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))

        # Bottom layer: dilated convolution, no further pooling.
        hx7 = self.rebnconv7(hx6)

        # Decoder: upsample to the skip feature's size, concatenate, convolve.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx5d = self.rebnconv5d(torch.cat((_upsample_like(hx6d, hx5), hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
|
|
|
|
|
|
### RSU-6 ###
|
|
class RSU6(nn.Module):
    """RSU-6 block: a small encoder/decoder with a residual connection.

    Five ``REBNCONV`` encoder layers separated by four 2x max-pooling steps,
    a dilated (rate 2) bottom layer, and a decoder that concatenates each
    upsampled result with the encoder feature of the same resolution. The
    output of the input convolution is added to the decoder output.

    Args:
        in_ch: Number of channels of the input tensor.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of channels of the returned tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Attribute names become state-dict keys; keep them stable.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, with ``out_ch`` channels."""
        hxin = self.rebnconvin(x)

        # Encoder.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))

        # Bottom layer: dilated convolution, no further pooling.
        hx6 = self.rebnconv6(hx5)

        # Decoder: upsample to the skip feature's size, concatenate, convolve.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
|
|
|
|
|
|
### RSU-5 ###
|
|
class RSU5(nn.Module):
    """RSU-5 block: a small encoder/decoder with a residual connection.

    Four ``REBNCONV`` encoder layers separated by three 2x max-pooling steps,
    a dilated (rate 2) bottom layer, and a decoder that concatenates each
    upsampled result with the encoder feature of the same resolution. The
    output of the input convolution is added to the decoder output.

    Args:
        in_ch: Number of channels of the input tensor.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of channels of the returned tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Attribute names become state-dict keys; keep them stable.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, with ``out_ch`` channels."""
        hxin = self.rebnconvin(x)

        # Encoder.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))

        # Bottom layer: dilated convolution, no further pooling.
        hx5 = self.rebnconv5(hx4)

        # Decoder: upsample to the skip feature's size, concatenate, convolve.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
|
|
|
|
|
|
### RSU-4 ###
|
|
class RSU4(nn.Module):
    """RSU-4 block: a small encoder/decoder with a residual connection.

    Three ``REBNCONV`` encoder layers separated by two 2x max-pooling steps,
    a dilated (rate 2) bottom layer, and a decoder that concatenates each
    upsampled result with the encoder feature of the same resolution. The
    output of the input convolution is added to the decoder output.

    Args:
        in_ch: Number of channels of the input tensor.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of channels of the returned tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Attribute names become state-dict keys; keep them stable.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, with ``out_ch`` channels."""
        hxin = self.rebnconvin(x)

        # Encoder.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))

        # Bottom layer: dilated convolution, no further pooling.
        hx4 = self.rebnconv4(hx3)

        # Decoder: upsample to the skip feature's size, concatenate, convolve.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
|
|
|
|
|
|
### RSU-4F ###
|
|
class RSU4F(nn.Module):
    """RSU-4F block: the pooling-free variant of the RSU blocks.

    No pooling or upsampling is performed; instead the dilation rate grows
    on the way down (1, 2, 4, 8) and shrinks on the way back up (4, 2, 1).
    Each decoder layer receives the previous result concatenated with the
    encoder feature of the same depth, and the output of the input
    convolution is added to the final decoder output.

    Args:
        in_ch: Number of channels of the input tensor.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of channels of the returned tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Attribute names become state-dict keys; keep them stable.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder layers take the concatenation of two mid_ch feature maps.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, with ``out_ch`` channels."""
        hxin = self.rebnconvin(x)

        # Encoder with increasing dilation.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)

        # Decoder with decreasing dilation; skip features are concatenated
        # along the channel dimension.
        decoded = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        decoded = self.rebnconv2d(torch.cat((decoded, hx2), 1))
        decoded = self.rebnconv1d(torch.cat((decoded, hx1), 1))

        return decoded + hxin
|
|
|
|
|
|
class ISNetDIS(nn.Module):
    """Segmentation network built from nested RSU blocks.

    A stride-2 convolution feeds six encoder stages (RSU7, RSU6, RSU5, RSU4,
    RSU4F, RSU4F) separated by 2x max-pooling. Five decoder stages then walk
    back up, each taking the upsampled previous result concatenated with the
    encoder feature of the same resolution. A final 3x3 convolution produces
    ``out_ch`` channels, which are resized to the input's spatial size and
    passed through a sigmoid.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels of the returned map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Attribute names become state-dict keys; keep them stable.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Not used by forward(); kept so the module's attributes stay the
        # same as before (it holds no parameters).
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)

    def forward(self, x):
        """Return a sigmoid map with ``x``'s spatial size and ``out_ch`` channels."""
        hxin = self.conv_in(x)

        # Encoder. Stage 1 consumes the conv output directly; the previous
        # code also ran pool_in here but discarded the result, so that call
        # is dropped without changing the output.
        hx1 = self.stage1(hxin)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # Decoder: upsample to the skip feature's size, concatenate along the
        # channel dimension, then run the decoder stage.
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # Side output, resized back to the network input's spatial size.
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)
        return d1.sigmoid()
|
|
|
|
|
|
# Download location and MD5 checksum of the pretrained weights.
ANIME_SEG_MODELS = {
    "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
    "md5": "5f25479076b73074730ab8de9e8f2051",
}
|
|
|
|
|
|
class AnimeSeg(BasePlugin):
    """Plugin that segments the foreground of an image with ``ISNetDIS``.

    ``gen_mask`` returns the predicted mask and ``gen_image`` returns an RGBA
    cut-out of the input in which the mask controls transparency.
    """

    # Model from: https://github.com/SkyTNT/anime-segmentation
    name = "AnimeSeg"
    support_gen_image = True
    support_gen_mask = True

    def __init__(self):
        super().__init__()
        self.model = load_model(
            ISNetDIS(),
            ANIME_SEG_MODELS["url"],
            "cpu",
            ANIME_SEG_MODELS["md5"],
        )

    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Return an RGBA array of ``rgb_np_img`` composited through the mask.

        ``req`` is not used.
        """
        mask = Image.fromarray(self.forward(rgb_np_img), mode="L")
        orig_h, orig_w = rgb_np_img.shape[0], rgb_np_img.shape[1]
        # Fully transparent canvas; the mask decides how much of the input
        # image is pasted onto it.
        transparent = Image.new("RGBA", (orig_w, orig_h), 0)
        cutout = Image.composite(Image.fromarray(rgb_np_img), transparent, mask)
        return np.asarray(cutout)

    def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Return the uint8 mask for ``rgb_np_img``. ``req`` is not used."""
        return self.forward(rgb_np_img)

    @torch.inference_mode()
    def forward(self, rgb_np_img):
        """Run the model on ``rgb_np_img`` and return a uint8 mask.

        The image is resized so its longer side is 1024 px, centred on a
        zero-filled 1024x1024 canvas, scaled to [0, 1] and passed through the
        model. The padded border is cropped from the prediction, which is
        then resized back to the original height and width and scaled to
        0..255.

        Args:
            rgb_np_img: Image array indexed as (height, width, channel) with
                3 channels (assumes uint8 RGB, given the division by 255;
                confirm against callers).

        Returns:
            ``uint8`` mask with the height and width of ``rgb_np_img``.
        """
        side = 1024

        orig_h, orig_w = rgb_np_img.shape[0], rgb_np_img.shape[1]
        # Fit the longer side to ``side``. For extreme aspect ratios the
        # truncated short side can reach 0, which cv2.resize rejects, so it
        # is clamped to at least one pixel.
        if orig_h > orig_w:
            new_h, new_w = side, max(1, int(side * orig_w / orig_h))
        else:
            new_h, new_w = max(1, int(side * orig_h / orig_w)), side

        # Offsets that centre the resized image on the square canvas.
        top = (side - new_h) // 2
        left = (side - new_w) // 2

        canvas = np.zeros([side, side, 3], dtype=np.float32)
        canvas[top : top + new_h, left : left + new_w] = (
            cv2.resize(rgb_np_img, (new_w, new_h)) / 255
        )

        # HWC -> CHW, then add the batch dimension.
        batch = torch.from_numpy(canvas.transpose((2, 0, 1))).unsqueeze(0).float()

        mask = self.model(batch)
        # Drop the padding so only the region covered by the image remains.
        mask = mask[0, :, top : top + new_h, left : left + new_w]
        mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (orig_w, orig_h))
        return (mask * 255).astype("uint8")
|