fix png alpha channel lose

This commit is contained in:
Sanster 2022-04-09 08:12:37 +08:00
parent caed45b520
commit 1b1aade067
2 changed files with 14 additions and 8 deletions

View File

@ -37,17 +37,19 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
def load_img(img_bytes, gray: bool = False): def load_img(img_bytes, gray: bool = False):
alpha_channel = None
nparr = np.frombuffer(img_bytes, np.uint8) nparr = np.frombuffer(img_bytes, np.uint8)
if gray: if gray:
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
else: else:
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
if len(np_img.shape) == 3 and np_img.shape[2] == 4: if len(np_img.shape) == 3 and np_img.shape[2] == 4:
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
else: else:
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
return np_img return np_img, alpha_channel
def norm_img(np_img): def norm_img(np_img):

10
main.py
View File

@ -10,7 +10,7 @@ from typing import Union
import cv2 import cv2
import torch import torch
import numpy as np
from lama_cleaner.lama import LaMa from lama_cleaner.lama import LaMa
from lama_cleaner.ldm import LDM from lama_cleaner.ldm import LDM
@ -68,7 +68,7 @@ def process():
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image = load_img(origin_image_bytes) image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -83,7 +83,7 @@ def process():
print(f"Resized image shape: {image.shape}") print(f"Resized image shape: {image.shape}")
image = norm_img(image) image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask) mask = norm_img(mask)
@ -92,6 +92,10 @@ def process():
print(f"process time: {(time.time() - start) * 1000}ms") print(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache() torch.cuda.empty_cache()
if alpha_channel is not None:
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes) ext = get_image_ext(origin_image_bytes)
return send_file( return send_file(