fix png alpha channel lose
This commit is contained in:
parent
caed45b520
commit
1b1aade067
@ -37,17 +37,19 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def load_img(img_bytes, gray: bool = False):
|
def load_img(img_bytes, gray: bool = False):
|
||||||
|
alpha_channel = None
|
||||||
nparr = np.frombuffer(img_bytes, np.uint8)
|
nparr = np.frombuffer(img_bytes, np.uint8)
|
||||||
if gray:
|
if gray:
|
||||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
||||||
else:
|
else:
|
||||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
||||||
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
||||||
|
alpha_channel = np_img[:, :, -1]
|
||||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
||||||
else:
|
else:
|
||||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
return np_img
|
return np_img, alpha_channel
|
||||||
|
|
||||||
|
|
||||||
def norm_img(np_img):
|
def norm_img(np_img):
|
||||||
|
18
main.py
18
main.py
@ -10,7 +10,7 @@ from typing import Union
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
from lama_cleaner.lama import LaMa
|
from lama_cleaner.lama import LaMa
|
||||||
from lama_cleaner.ldm import LDM
|
from lama_cleaner.ldm import LDM
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ def process():
|
|||||||
# RGB
|
# RGB
|
||||||
origin_image_bytes = input["image"].read()
|
origin_image_bytes = input["image"].read()
|
||||||
|
|
||||||
image = load_img(origin_image_bytes)
|
image, alpha_channel = load_img(origin_image_bytes)
|
||||||
original_shape = image.shape
|
original_shape = image.shape
|
||||||
interpolation = cv2.INTER_CUBIC
|
interpolation = cv2.INTER_CUBIC
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ def process():
|
|||||||
print(f"Resized image shape: {image.shape}")
|
print(f"Resized image shape: {image.shape}")
|
||||||
image = norm_img(image)
|
image = norm_img(image)
|
||||||
|
|
||||||
mask = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
|
|
||||||
@ -92,6 +92,10 @@ def process():
|
|||||||
print(f"process time: {(time.time() - start) * 1000}ms")
|
print(f"process time: {(time.time() - start) * 1000}ms")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
if alpha_channel is not None:
|
||||||
|
res_np_img = np.concatenate(
|
||||||
|
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||||
|
)
|
||||||
|
|
||||||
ext = get_image_ext(origin_image_bytes)
|
ext = get_image_ext(origin_image_bytes)
|
||||||
return send_file(
|
return send_file(
|
||||||
@ -131,9 +135,9 @@ def get_args_parser():
|
|||||||
nargs=2,
|
nargs=2,
|
||||||
type=int,
|
type=int,
|
||||||
help="If image size large then crop-trigger-size, "
|
help="If image size large then crop-trigger-size, "
|
||||||
"crop each area from original image to do inference."
|
"crop each area from original image to do inference."
|
||||||
"Mainly for performance and memory reasons"
|
"Mainly for performance and memory reasons"
|
||||||
"Only for lama",
|
"Only for lama",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--crop-margin",
|
"--crop-margin",
|
||||||
@ -146,7 +150,7 @@ def get_args_parser():
|
|||||||
default=50,
|
default=50,
|
||||||
type=int,
|
type=int,
|
||||||
help="Steps for DDIM sampling process."
|
help="Steps for DDIM sampling process."
|
||||||
"The larger the value, the better the result, but it will be more time-consuming",
|
"The larger the value, the better the result, but it will be more time-consuming",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", default="cuda", type=str)
|
parser.add_argument("--device", default="cuda", type=str)
|
||||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
|
Loading…
Reference in New Issue
Block a user