diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index c39f395..101d894 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -37,17 +37,19 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
 
 
 def load_img(img_bytes, gray: bool = False):
+    alpha_channel = None
     nparr = np.frombuffer(img_bytes, np.uint8)
     if gray:
         np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
     else:
         np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
         if len(np_img.shape) == 3 and np_img.shape[2] == 4:
+            alpha_channel = np_img[:, :, -1]
             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
         else:
             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
 
-    return np_img
+    return np_img, alpha_channel
 
 
 def norm_img(np_img):
diff --git a/main.py b/main.py
index f694420..f6051b1 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from typing import Union
 
 import cv2
 import torch
-
+import numpy as np
 from lama_cleaner.lama import LaMa
 from lama_cleaner.ldm import LDM
 
@@ -68,7 +68,7 @@ def process():
 
     # RGB
     origin_image_bytes = input["image"].read()
-    image = load_img(origin_image_bytes)
+    image, alpha_channel = load_img(origin_image_bytes)
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC
 
@@ -83,7 +83,7 @@ def process():
     print(f"Resized image shape: {image.shape}")
     image = norm_img(image)
 
-    mask = load_img(input["mask"].read(), gray=True)
+    mask, _ = load_img(input["mask"].read(), gray=True)
     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
     mask = norm_img(mask)
 
@@ -92,6 +92,10 @@ def process():
 
     print(f"process time: {(time.time() - start) * 1000}ms")
     torch.cuda.empty_cache()
+    if alpha_channel is not None:
+        res_np_img = np.concatenate(
+            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
+        )
 
     ext = get_image_ext(origin_image_bytes)
     return send_file(
@@ -131,9 +135,9 @@ def get_args_parser():
         nargs=2,
         type=int,
         help="If image size large then crop-trigger-size, "
-             "crop each area from original image to do inference."
-             "Mainly for performance and memory reasons"
-             "Only for lama",
+        "crop each area from original image to do inference."
+        "Mainly for performance and memory reasons"
+        "Only for lama",
     )
     parser.add_argument(
         "--crop-margin",
@@ -146,7 +150,7 @@ def get_args_parser():
         default=50,
         type=int,
         help="Steps for DDIM sampling process."
-             "The larger the value, the better the result, but it will be more time-consuming",
+        "The larger the value, the better the result, but it will be more time-consuming",
     )
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
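
Minimal sketch of the alpha round trip this patch introduces, for illustration only: the 4x4 test image and the pass-through result standing in for the model output are assumptions, not part of the diff; only the cv2/numpy calls mirror the patched load_img()/process() logic.

# Decode an RGBA PNG, set the alpha plane aside while the RGB data is
# processed, then concatenate it back onto the result before re-encoding.
import cv2
import numpy as np

# Build a 4x4 RGBA test image and encode it to PNG bytes (stand-in for upload).
rgba = np.zeros((4, 4, 4), dtype=np.uint8)
rgba[..., :3] = 128   # gray RGB
rgba[..., 3] = 200    # semi-transparent alpha
ok, png = cv2.imencode(".png", cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
assert ok

# Decode as load_img() now does: keep the alpha channel aside, work on RGB only.
nparr = np.frombuffer(png.tobytes(), np.uint8)
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)   # BGRA
alpha_channel = np_img[:, :, -1]
rgb = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)

res_np_img = rgb  # stand-in for the inpainting result (same height/width)

# Re-attach the alpha channel as process() now does before send_file().
if alpha_channel is not None:
    res_np_img = np.concatenate(
        (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
    )
assert res_np_img.shape == (4, 4, 4)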