IOPaint/main.py

#!/usr/bin/env python3
import argparse
import imghdr
import io
import logging
import multiprocessing
import os
import time
from pathlib import Path
from typing import Union

import cv2
import numpy as np
import torch
from loguru import logger

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
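
# The private torch._C calls below switch off the TorchScript JIT fusers;
# presumably this avoids fusion-related issues when running the traced
# inpainting models. Any failure here is intentionally ignored.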
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from flask import Flask, request, send_file
from flask_cors import CORS
from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
)

NUM_THREADS = str(multiprocessing.cpu_count())
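# Pin every numeric backend's thread pool (OpenMP, OpenBLAS, MKL, vecLib,
# NumExpr) to the machine's CPU count.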
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
2021-11-16 14:21:41 +01:00
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
2021-11-15 08:22:34 +01:00
2022-04-09 01:23:33 +02:00
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
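

# Forward records from Flask's standard-library logger into loguru.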
class InterceptHandler(logging.Handler):
    def emit(self, record):
        logger_opt = logger.opt(depth=6, exception=record.exc_info)
        logger_opt.log(record.levelno, record.getMessage())


app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])

model: ModelManager = None
device = None
input_image_path: str = None


def get_image_ext(img_bytes):
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w
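

# /inpaint reads the uploaded image and mask, resizes both to the requested
# size limit, runs the active inpainting model, re-attaches the alpha channel,
# and returns the result in the original image format.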
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
2022-03-04 06:44:53 +01:00
# RGB
2022-04-09 01:23:33 +02:00
origin_image_bytes = input["image"].read()
2022-04-09 02:12:37 +02:00
image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
2022-04-15 18:11:51 +02:00
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
2022-04-15 18:11:51 +02:00
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
2022-04-09 01:23:33 +02:00
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
2022-04-15 18:11:51 +02:00
logger.info(f"Resized image shape: {image.shape}")
2022-04-09 02:12:37 +02:00
mask, _ = load_img(input["mask"].read(), gray=True)
2022-04-09 01:23:33 +02:00
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
2022-03-20 15:42:59 +01:00
start = time.time()
2022-04-15 18:11:51 +02:00
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
2022-04-09 02:12:37 +02:00
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
2022-04-09 02:12:37 +02:00
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
2022-04-09 01:23:33 +02:00
ext = get_image_ext(origin_image_bytes)
2021-11-15 08:22:34 +01:00
return send_file(
2022-04-09 01:23:33 +02:00
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype=f"image/{ext}",
2021-11-15 08:22:34 +01:00
)
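

# Model management endpoints: report the active model, check whether a given
# model's weights are already downloaded, and switch the model at runtime.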
@app.route("/model")
def current_model():
return model.name, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/model", methods=["POST"])
2022-04-15 18:11:51 +02:00
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
2021-11-15 08:22:34 +01:00
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
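

# Serve the image passed via --input so the web UI can preload it.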
@app.route("/inputimage")
def set_input_photo():
2022-03-27 07:37:26 +02:00
if input_image_path:
2022-04-09 01:23:33 +02:00
with open(input_image_path, "rb") as f:
2022-03-27 07:37:26 +02:00
image_in_bytes = f.read()
2022-04-09 01:23:33 +02:00
return send_file(
2022-04-15 18:11:51 +02:00
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
2022-04-09 01:23:33 +02:00
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
2022-04-09 01:23:33 +02:00
return "No Input Image"


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input", type=str, help="Path to image you want to load by default"
    )
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
    parser.add_argument(
        "--gui-size",
        default=[1600, 1000],
        nargs=2,
        type=int,
        help="Set window size for GUI",
    )
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()
    if args.input is not None:
        if not os.path.exists(args.input):
            parser.error(f"invalid --input: {args.input} does not exist")
        if imghdr.what(args.input) is None:
            parser.error(f"invalid --input: {args.input} is not a valid image file")
    return args


def main():
    global model
    global device
    global input_image_path

    args = get_args_parser()
    device = torch.device(args.device)
    input_image_path = args.input
    model = ModelManager(name=args.model, device=device)

    if args.gui:
        app_width, app_height = args.gui_size
        # flaskwebgui wraps the Flask server in a native desktop window.
        from flaskwebgui import FlaskUI

        ui = FlaskUI(app, width=app_width, height=app_height)
        ui.run()
    else:
        app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == "__main__":
    main()