2021-11-15 08:22:34 +01:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
2021-12-16 14:29:32 +01:00
|
|
|
import argparse
|
2021-11-15 08:22:34 +01:00
|
|
|
import io
|
2021-12-16 14:29:32 +01:00
|
|
|
import multiprocessing
|
2021-11-15 08:22:34 +01:00
|
|
|
import os
|
2022-03-20 15:42:59 +01:00
|
|
|
import time
|
2022-03-24 19:33:13 +01:00
|
|
|
import imghdr
|
2021-11-27 13:37:37 +01:00
|
|
|
from typing import Union
|
2021-12-16 14:29:32 +01:00
|
|
|
|
2021-11-15 08:22:34 +01:00
|
|
|
import cv2
|
|
|
|
import torch
|
2022-04-09 02:12:37 +02:00
|
|
|
import numpy as np
|
2022-03-04 06:44:53 +01:00
|
|
|
from lama_cleaner.lama import LaMa
|
|
|
|
from lama_cleaner.ldm import LDM
|
|
|
|
|
2022-03-23 17:07:33 +01:00
|
|
|
from flaskwebgui import FlaskUI
|
|
|
|
|
2022-02-18 06:29:10 +01:00
|
|
|
# Disable every TorchScript JIT fuser (CPU, GPU, TensorExpr, nvFuser): fused
# kernels have been a source of wrong/non-deterministic inpainting results.
# These are private torch APIs, so tolerate their absence across versions —
# but catch Exception, not everything: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass
|
|
|
|
|
2021-11-15 08:22:34 +01:00
|
|
|
from flask import Flask, request, send_file
|
|
|
|
from flask_cors import CORS
|
|
|
|
|
2022-02-09 11:01:19 +01:00
|
|
|
from lama_cleaner.helper import (
|
|
|
|
load_img,
|
|
|
|
norm_img,
|
|
|
|
numpy_to_bytes,
|
2022-03-12 13:43:47 +01:00
|
|
|
resize_max_size,
|
|
|
|
)
|
2021-11-15 20:11:46 +01:00
|
|
|
|
|
|
|
# Cap every numeric backend's thread pool at the physical core count so the
# nested pools (OpenMP, OpenBLAS, MKL, vecLib, NumExpr) don't oversubscribe
# the CPU when they run inside the inference process.
NUM_THREADS = str(multiprocessing.cpu_count())

for _thread_env_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_env_var] = NUM_THREADS

# Honor a user-supplied cache directory for torch model downloads.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
2021-11-15 08:22:34 +01:00
|
|
|
|
2022-04-09 01:23:33 +02:00
|
|
|
# Directory holding the compiled frontend assets; overridable for packaged
# installs via LAMA_CLEANER_BUILD_DIR.
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")

app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False  # keep non-ASCII characters in JSON responses
CORS(app)  # frontend may be served from a different origin during development

# Populated by main() before the server starts handling requests.
model = None
device = None
# FIX: was annotated `str` while initialized to None; Union is already imported.
input_image_path: Union[str, None] = None
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
|
2022-04-09 01:23:33 +02:00
|
|
|
def get_image_ext(img_bytes):
    """Guess the image format of raw bytes, falling back to "jpeg"."""
    detected = imghdr.what("", img_bytes)
    return detected if detected is not None else "jpeg"
|
|
|
|
|
|
|
|
|
2021-11-15 08:22:34 +01:00
|
|
|
@app.route("/inpaint", methods=["POST"])
|
|
|
|
def process():
|
|
|
|
input = request.files
|
2022-03-04 06:44:53 +01:00
|
|
|
# RGB
|
2022-04-09 01:23:33 +02:00
|
|
|
origin_image_bytes = input["image"].read()
|
|
|
|
|
2022-04-09 02:12:37 +02:00
|
|
|
image, alpha_channel = load_img(origin_image_bytes)
|
2021-11-27 13:37:37 +01:00
|
|
|
original_shape = image.shape
|
|
|
|
interpolation = cv2.INTER_CUBIC
|
|
|
|
|
|
|
|
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
|
|
|
if size_limit == "Original":
|
|
|
|
size_limit = max(image.shape)
|
|
|
|
else:
|
|
|
|
size_limit = int(size_limit)
|
|
|
|
|
|
|
|
print(f"Origin image shape: {original_shape}")
|
2022-04-09 01:23:33 +02:00
|
|
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
2021-11-27 13:37:37 +01:00
|
|
|
print(f"Resized image shape: {image.shape}")
|
|
|
|
image = norm_img(image)
|
|
|
|
|
2022-04-09 02:12:37 +02:00
|
|
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
2022-04-09 01:23:33 +02:00
|
|
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
2021-11-27 13:37:37 +01:00
|
|
|
mask = norm_img(mask)
|
|
|
|
|
2022-03-20 15:42:59 +01:00
|
|
|
start = time.time()
|
2022-03-04 06:44:53 +01:00
|
|
|
res_np_img = model(image, mask)
|
2022-03-20 15:42:59 +01:00
|
|
|
print(f"process time: {(time.time() - start) * 1000}ms")
|
2021-11-27 13:37:37 +01:00
|
|
|
|
2022-03-12 13:43:47 +01:00
|
|
|
torch.cuda.empty_cache()
|
2022-04-09 02:12:37 +02:00
|
|
|
if alpha_channel is not None:
|
|
|
|
res_np_img = np.concatenate(
|
|
|
|
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
|
|
|
)
|
2022-03-12 13:43:47 +01:00
|
|
|
|
2022-04-09 01:23:33 +02:00
|
|
|
ext = get_image_ext(origin_image_bytes)
|
2021-11-15 08:22:34 +01:00
|
|
|
return send_file(
|
2022-04-09 01:23:33 +02:00
|
|
|
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
|
|
|
mimetype=f"image/{ext}",
|
2021-11-15 08:22:34 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/")
|
|
|
|
def index():
|
|
|
|
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
|
|
|
|
|
|
|
|
2022-04-09 01:23:33 +02:00
|
|
|
@app.route("/inputimage")
|
2022-03-24 19:33:13 +01:00
|
|
|
def set_input_photo():
|
2022-03-27 07:37:26 +02:00
|
|
|
if input_image_path:
|
2022-04-09 01:23:33 +02:00
|
|
|
with open(input_image_path, "rb") as f:
|
2022-03-27 07:37:26 +02:00
|
|
|
image_in_bytes = f.read()
|
2022-04-09 01:23:33 +02:00
|
|
|
return send_file(
|
|
|
|
io.BytesIO(image_in_bytes),
|
|
|
|
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
|
|
|
)
|
2022-03-24 19:33:13 +01:00
|
|
|
else:
|
2022-04-09 01:23:33 +02:00
|
|
|
return "No Input Image"
|
2022-03-24 19:33:13 +01:00
|
|
|
|
|
|
|
|
2021-11-15 08:22:34 +01:00
|
|
|
def get_args_parser():
    """Build the CLI parser, parse sys.argv, validate --input, and return args.

    Calls parser.error() (which exits the process) when --input points to a
    missing file or to a file that is not a recognized image format.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input", type=str, help="Path to image you want to load by default"
    )
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
    parser.add_argument(
        "--crop-trigger-size",
        default=[2042, 2042],
        nargs=2,
        type=int,
        # FIX: help text previously rendered as "...do inference.Mainly for
        # ...reasonsOnly for lama" (missing spaces between the concatenated
        # fragments) and said "large then" instead of "larger than".
        help="If image size is larger than crop-trigger-size, "
        "crop each area from original image to do inference. "
        "Mainly for performance and memory reasons. "
        "Only for lama",
    )
    parser.add_argument(
        "--crop-margin",
        type=int,
        default=256,
        help="Margin around bounding box of painted stroke when crop mode triggered",
    )
    parser.add_argument(
        "--ldm-steps",
        default=50,
        type=int,
        help="Steps for DDIM sampling process. "
        "The larger the value, the better the result, but it will be more time-consuming",
    )
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
    parser.add_argument(
        "--gui-size",
        default=[1600, 1000],
        nargs=2,
        type=int,
        help="Set window size for GUI",
    )
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()

    # Fail fast on an unusable --input instead of 500-ing at request time.
    if args.input is not None:
        if not os.path.exists(args.input):
            parser.error(f"invalid --input: {args.input} not exists")
        if imghdr.what(args.input) is None:
            parser.error(f"invalid --input: {args.input} is not a valid image file")

    return args
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, build the requested inpainting model, start the server."""
    global model, device, input_image_path

    args = get_args_parser()
    device = torch.device(args.device)
    input_image_path = args.input

    # Instantiate the backend selected on the command line.
    if args.model == "lama":
        model = LaMa(
            crop_trigger_size=args.crop_trigger_size,
            crop_margin=args.crop_margin,
            device=device,
        )
    elif args.model == "ldm":
        model = LDM(device, steps=args.ldm_steps)
    else:
        raise NotImplementedError(f"Not supported model: {args.model}")

    # Plain Flask server by default; desktop window when --gui is given.
    if not args.gui:
        app.run(host="127.0.0.1", port=args.port, debug=args.debug)
    else:
        width, height = args.gui_size
        FlaskUI(app, width=width, height=height).run()
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|