Fix PNG alpha channel loss

This commit is contained in:
Sanster 2022-04-09 08:12:37 +08:00
parent caed45b520
commit 1b1aade067
2 changed files with 14 additions and 8 deletions

View File

@ -37,17 +37,19 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
def load_img(img_bytes, gray: bool = False): def load_img(img_bytes, gray: bool = False):
alpha_channel = None
nparr = np.frombuffer(img_bytes, np.uint8) nparr = np.frombuffer(img_bytes, np.uint8)
if gray: if gray:
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
else: else:
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
if len(np_img.shape) == 3 and np_img.shape[2] == 4: if len(np_img.shape) == 3 and np_img.shape[2] == 4:
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
else: else:
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
return np_img return np_img, alpha_channel
def norm_img(np_img): def norm_img(np_img):

18
main.py
View File

@ -10,7 +10,7 @@ from typing import Union
import cv2 import cv2
import torch import torch
import numpy as np
from lama_cleaner.lama import LaMa from lama_cleaner.lama import LaMa
from lama_cleaner.ldm import LDM from lama_cleaner.ldm import LDM
@ -68,7 +68,7 @@ def process():
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image = load_img(origin_image_bytes) image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -83,7 +83,7 @@ def process():
print(f"Resized image shape: {image.shape}") print(f"Resized image shape: {image.shape}")
image = norm_img(image) image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask) mask = norm_img(mask)
@ -92,6 +92,10 @@ def process():
print(f"process time: {(time.time() - start) * 1000}ms") print(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache() torch.cuda.empty_cache()
if alpha_channel is not None:
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes) ext = get_image_ext(origin_image_bytes)
return send_file( return send_file(
@ -131,9 +135,9 @@ def get_args_parser():
nargs=2, nargs=2,
type=int, type=int,
help="If image size large then crop-trigger-size, " help="If image size large then crop-trigger-size, "
"crop each area from original image to do inference." "crop each area from original image to do inference."
"Mainly for performance and memory reasons" "Mainly for performance and memory reasons"
"Only for lama", "Only for lama",
) )
parser.add_argument( parser.add_argument(
"--crop-margin", "--crop-margin",
@ -146,7 +150,7 @@ def get_args_parser():
default=50, default=50,
type=int, type=int,
help="Steps for DDIM sampling process." help="Steps for DDIM sampling process."
"The larger the value, the better the result, but it will be more time-consuming", "The larger the value, the better the result, but it will be more time-consuming",
) )
parser.add_argument("--device", default="cuda", type=str) parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")