2021-11-15 08:22:34 +01:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
2021-12-16 14:29:32 +01:00
|
|
|
import argparse
|
2021-11-15 08:22:34 +01:00
|
|
|
import io
|
2021-12-16 14:29:32 +01:00
|
|
|
import multiprocessing
|
2021-11-15 08:22:34 +01:00
|
|
|
import os
|
|
|
|
import time
|
2021-11-27 13:37:37 +01:00
|
|
|
from distutils.util import strtobool
|
|
|
|
from typing import Union
|
2021-12-16 14:29:32 +01:00
|
|
|
|
2021-11-15 08:22:34 +01:00
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from flask import Flask, request, send_file
|
|
|
|
from flask_cors import CORS
|
|
|
|
|
2021-12-16 14:29:32 +01:00
|
|
|
from lama_cleaner.helper import (download_model, load_img, norm_img,
|
|
|
|
numpy_to_bytes, pad_img_to_modulo,
|
|
|
|
resize_max_size)
|
2021-11-15 20:11:46 +01:00
|
|
|
|
|
|
|
# Thread budget for the numeric back ends: every logical CPU reported by the OS.
NUM_THREADS = str(multiprocessing.cpu_count())

# NOTE: numpy and torch are imported before this point, so native runtimes that
# read these variables only once at library load time (OpenMP / BLAS) may already
# be initialised in THIS process. The variables still propagate to any child
# processes started later.
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
# Apply the same budget to torch's intra-op pool directly, because the
# environment variables above arrive after `import torch`.
torch.set_num_threads(int(NUM_THREADS))

# Optional override for where torch stores downloaded weights.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

# Location of the built front end; its "static" subfolder is served by Flask.
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")

app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
# Emit non-ASCII characters unescaped in JSON responses.
app.config["JSON_AS_ASCII"] = False
CORS(app)

# Populated by main(): the TorchScript inpainting model and the torch.device it runs on.
model = None
device = None
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/inpaint", methods=["POST"])
def process():
    """Inpaint the uploaded image inside the uploaded mask.

    Request (multipart form):
        image: file field holding the picture to edit (decoded by ``load_img``).
        mask: file field holding the region to fill (decoded as grayscale).
        sizeLimit: optional form field. Longest side the inputs are resized to
            before inference, or ``"Original"`` to keep the full size.
            Defaults to ``"1080"``.

    Returns:
        The inpainted picture, resized back to the uploaded image's shape, sent
        as an attachment named ``result.jpeg``; or a 400 response when
        ``sizeLimit`` is neither ``"Original"`` nor an integer.
    """
    files = request.files  # renamed from `input` to avoid shadowing the builtin
    image = load_img(files["image"].read())
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        try:
            size_limit = int(size_limit)
        except ValueError:
            # Bad client input: answer 400 instead of letting it surface as a 500.
            return f"Invalid sizeLimit: {size_limit!r}", 400

    print(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {image.shape}")
    image = norm_img(image)

    # The mask is resized with the same limit so it stays aligned with the image.
    mask = load_img(files["mask"].read(), gray=True)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    mask = norm_img(mask)

    res_np_img = run(image, mask)

    # Resize back to the size the client uploaded (cv2 takes dsize as (width, height)).
    res_np_img = cv2.resize(
        res_np_img,
        dsize=(original_shape[1], original_shape[0]),
        interpolation=interpolation,
    )

    # NOTE(review): assumes numpy_to_bytes encodes JPEG, to match the mimetype.
    # Flask 2.0 renamed `attachment_filename` to `download_name`, and 2.2 removed
    # the old name; confirm the pinned Flask version.
    return send_file(
        io.BytesIO(numpy_to_bytes(res_np_img)),
        mimetype="image/jpeg",
        as_attachment=True,
        attachment_filename="result.jpeg",
    )
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/")
def index():
    """Serve the front end's ``index.html`` from ``BUILD_DIR``."""
    index_html = os.path.join(BUILD_DIR, "index.html")
    return send_file(index_html)
|
|
|
|
|
|
|
|
|
|
|
|
def run(image, mask):
    """Inpaint ``image`` where ``mask`` is set, using the global ``model`` on ``device``.

    Args:
        image: array laid out as [C, H, W].
        mask: array laid out as [1, H, W]; any value > 0 marks a pixel to fill.

    Returns:
        uint8 array of shape [H, W, C], cropped back to the input H and W,
        with channels passed through ``cv2.COLOR_BGR2RGB`` (a BGR image per the
        original contract).
    """
    origin_height, origin_width = image.shape[1:]
    # Pad spatial dims to a multiple of 8; the padding is cropped off again below.
    image = pad_img_to_modulo(image, mod=8)
    mask = pad_img_to_modulo(mask, mod=8)

    # Binarize the mask (integer 0/1 array).
    mask = (mask > 0) * 1
    # Add a leading batch dimension and move to the inference device.
    image = torch.from_numpy(image).unsqueeze(0).to(device)
    mask = torch.from_numpy(mask).unsqueeze(0).to(device)

    start = time.perf_counter()
    with torch.no_grad():
        inpainted_image = model(image, mask)
    # Copying to host waits for the device to finish. Measuring after it reports
    # real inference time rather than just the asynchronous CUDA kernel launch.
    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
    print(f"process time: {(time.perf_counter() - start)*1000}ms")

    cur_res = cur_res[0:origin_height, 0:origin_width, :]
    # Scale the [0, 1] model output to 8-bit, clipping any overshoot.
    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
    return cur_res
|
|
|
|
|
|
|
|
|
|
|
|
def get_args_parser(argv=None):
    """Parse the server's command-line options.

    Despite the name, this returns the parsed namespace, not the parser.

    Args:
        argv: argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``port`` (int, default 8080), ``device``
        (str, default "cuda") and ``debug`` (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Load the LaMa TorchScript model and start the Flask server.

    The model path comes from the ``LAMA_MODEL`` environment variable when it
    is set and non-empty; otherwise ``download_model()`` supplies it.

    Raises:
        FileNotFoundError: if ``LAMA_MODEL`` names a path that does not exist.
    """
    global model
    global device

    args = get_args_parser()
    device = torch.device(args.device)

    # Look up once; an empty value is treated the same as unset.
    model_path = os.environ.get("LAMA_MODEL")
    if model_path:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
    else:
        model_path = download_model()

    # Deserialize on CPU, then move to the requested device.
    model = torch.jit.load(model_path, map_location="cpu")
    model = model.to(device)
    model.eval()

    # SECURITY(review): --debug enables Flask's interactive debugger, which can
    # execute arbitrary code. Binding 0.0.0.0 exposes it to the whole network;
    # consider host="127.0.0.1" when debugging.
    app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
2021-11-15 08:22:34 +01:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    main()
|