diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 9c6e986..c39f395 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -30,8 +30,8 @@ def ceil_modulo(x, mod): return (x // mod + 1) * mod -def numpy_to_bytes(image_numpy: np.ndarray) -> bytes: - data = cv2.imencode(".jpg", image_numpy)[1] +def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: + data = cv2.imencode(f".{ext}", image_numpy)[1] image_bytes = data.tobytes() return image_bytes @@ -92,7 +92,9 @@ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]: """ height, width = mask.shape[1:] - _, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0) + _, thresh = cv2.threshold( + (mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0 + ) contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) boxes = [] diff --git a/main.py b/main.py index 08b4e42..f694420 100644 --- a/main.py +++ b/main.py @@ -44,8 +44,7 @@ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS if os.environ.get("CACHE_DIR"): os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] -BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", - "./lama_cleaner/app/build") +BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build") app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app.config["JSON_AS_ASCII"] = False @@ -56,11 +55,20 @@ device = None input_image_path: str = None +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + @app.route("/inpaint", methods=["POST"]) def process(): input = request.files # RGB - image = load_img(input["image"].read()) + origin_image_bytes = input["image"].read() + + image = load_img(origin_image_bytes) original_shape = image.shape interpolation = cv2.INTER_CUBIC @@ -71,14 +79,12 @@ def process(): size_limit = int(size_limit) print(f"Origin image shape: {original_shape}") - image = resize_max_size(image, size_limit=size_limit, - interpolation=interpolation) + image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) print(f"Resized image shape: {image.shape}") image = norm_img(image) mask = load_img(input["mask"].read(), gray=True) - mask = resize_max_size(mask, size_limit=size_limit, - interpolation=interpolation) + mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = norm_img(mask) start = time.time() @@ -87,11 +93,10 @@ def process(): torch.cuda.empty_cache() + ext = get_image_ext(origin_image_bytes) return send_file( - io.BytesIO(numpy_to_bytes(res_np_img)), - mimetype="image/jpeg", - as_attachment=True, - attachment_filename="result.jpeg", + io.BytesIO(numpy_to_bytes(res_np_img, ext)), + mimetype=f"image/{ext}", ) @@ -100,29 +105,42 @@ def index(): return send_file(os.path.join(BUILD_DIR, "index.html")) -@app.route('/inputimage') +@app.route("/inputimage") def set_input_photo(): if input_image_path: - with open(input_image_path, 'rb') as f: + with open(input_image_path, "rb") as f: image_in_bytes = f.read() - return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg') + return send_file( + io.BytesIO(image_in_bytes), + mimetype=f"image/{get_image_ext(image_in_bytes)}", + ) else: - return 'No Input Image' + return "No Input Image" def get_args_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "--input", type=str, help="Path to image you want to load by default") + "--input", type=str, help="Path to image you want to load by default" + ) parser.add_argument("--port", default=8080, type=int) parser.add_argument("--model", default="lama", choices=["lama", "ldm"]) - parser.add_argument("--crop-trigger-size", default=[2042, 2042], nargs=2, type=int, - help="If image size large then crop-trigger-size, " - "crop each area from original image to do inference." - "Mainly for performance and memory reasons" - "Only for lama") - parser.add_argument("--crop-margin", type=int, default=256, - help="Margin around bounding box of painted stroke when crop mode triggered") + parser.add_argument( + "--crop-trigger-size", + default=[2042, 2042], + nargs=2, + type=int, + help="If image size large then crop-trigger-size, " + "crop each area from original image to do inference." + "Mainly for performance and memory reasons" + "Only for lama", + ) + parser.add_argument( + "--crop-margin", + type=int, + default=256, + help="Margin around bounding box of painted stroke when crop mode triggered", + ) parser.add_argument( "--ldm-steps", default=50, @@ -131,10 +149,14 @@ def get_args_parser(): "The larger the value, the better the result, but it will be more time-consuming", ) parser.add_argument("--device", default="cuda", type=str) - parser.add_argument("--gui", action="store_true", - help="Launch as desktop app") - parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int, - help="Set window size for GUI") + parser.add_argument("--gui", action="store_true", help="Launch as desktop app") + parser.add_argument( + "--gui-size", + default=[1600, 1000], + nargs=2, + type=int, + help="Set window size for GUI", + ) parser.add_argument("--debug", action="store_true") args = parser.parse_args() @@ -157,8 +179,11 @@ def main(): input_image_path = args.input if args.model == "lama": - model = LaMa(crop_trigger_size=args.crop_trigger_size, - crop_margin=args.crop_margin, device=device) + model = LaMa( + crop_trigger_size=args.crop_trigger_size, + crop_margin=args.crop_margin, + device=device, + ) elif args.model == "ldm": model = LDM(device, steps=args.ldm_steps) else: